From f0caea9026c413cbcc965740965fff9441526a60 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Tue, 21 Nov 2023 14:23:47 +0800
Subject: [PATCH 01/47] =?UTF-8?q?feat:=20=E9=92=88=E5=AF=B9=20OneBot=20?=
=?UTF-8?q?=E5=92=8C=20NoneBot=20=E7=9A=84=E6=B6=88=E6=81=AF=E5=85=BC?=
=?UTF-8?q?=E5=AE=B9=E5=B1=82=E5=92=8C=E6=8F=92=E4=BB=B6=E7=9A=84=E5=88=9D?=
=?UTF-8?q?=E6=AD=A5=E9=80=82=E9=85=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
cores/qqbot/core.py | 113 +++++++--------
model/command/adapter/nonebot/command_arg.py | 2 +
model/command/adapter/nonebot/common.py | 32 ++++
model/command/adapter/nonebot/driver.py | 11 ++
model/command/adapter/onebot/bot.py | 2 +
model/command/adapter/onebot/message.py | 2 +
model/command/adapter/onebot/message_event.py | 2 +
.../command/adapter/onebot/message_segment.py | 2 +
model/command/adapter/protocol_adapter.py | 137 ++++++++++++++++++
model/command/command.py | 14 +-
model/command/openai_official.py | 8 +-
model/command/rev_chatgpt.py | 4 +-
model/platform/qqchan.py | 7 +-
13 files changed, 269 insertions(+), 67 deletions(-)
create mode 100644 model/command/adapter/nonebot/command_arg.py
create mode 100644 model/command/adapter/nonebot/common.py
create mode 100644 model/command/adapter/nonebot/driver.py
create mode 100644 model/command/adapter/onebot/bot.py
create mode 100644 model/command/adapter/onebot/message.py
create mode 100644 model/command/adapter/onebot/message_event.py
create mode 100644 model/command/adapter/onebot/message_segment.py
create mode 100644 model/command/adapter/protocol_adapter.py
diff --git a/cores/qqbot/core.py b/cores/qqbot/core.py
index 262483622..95000a6b6 100644
--- a/cores/qqbot/core.py
+++ b/cores/qqbot/core.py
@@ -207,6 +207,61 @@ def initBot(cfg, prov):
if 'reply_prefix' in cfg:
_global_object.reply_prefix = cfg['reply_prefix']
+
+ gu.log("--------加载机器人平台--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
+ thread_inst = None
+ admin_qq = cc.get('admin_qq', None)
+ admin_qqchan = cc.get('admin_qqchan', None)
+ if admin_qq == None:
+ gu.log("未设置管理者QQ号(管理者才能使用update/plugin等指令)", gu.LEVEL_WARNING)
+ admin_qq = input("请输入管理者QQ号(必须设置): ")
+ gu.log("管理者QQ号设置为: " + admin_qq, gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
+ cc.put('admin_qq', admin_qq)
+ if admin_qqchan == None:
+ gu.log("未设置管理者QQ频道用户号(管理者才能使用update/plugin等指令)", gu.LEVEL_WARNING)
+ admin_qqchan = input("请输入管理者频道用户号(不是QQ号, 可以先回车跳过然后在频道发送指令!myid获取): ")
+ if admin_qqchan == "":
+ gu.log("跳过设置管理者频道用户号", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
+ else:
+ gu.log("管理者频道用户号设置为: " + admin_qqchan, gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
+ cc.put('admin_qqchan', admin_qqchan)
+
+ gu.log("管理者QQ: " + admin_qq, gu.LEVEL_INFO)
+ gu.log("管理者频道用户号: " + admin_qqchan, gu.LEVEL_INFO)
+ _global_object.admin_qq = admin_qq
+ _global_object.admin_qqchan = admin_qqchan
+
+ # GOCQ
+ global gocq_bot
+
+ if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
+ gu.log("- 启用QQ机器人 -", gu.LEVEL_INFO)
+
+ global gocq_app, gocq_loop
+ gocq_loop = asyncio.new_event_loop()
+ gocq_bot = QQ(True, cc, gocq_loop)
+ thread_inst = threading.Thread(target=run_gocq_bot, args=(gocq_loop, gocq_bot, gocq_app), daemon=False)
+ thread_inst.start()
+ else:
+ gocq_bot = QQ(False)
+
+ _global_object.platform_qq = gocq_bot
+
+ gu.log("机器人部署教程: https://github.com/Soulter/QQChannelChatGPT/wiki/", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
+ gu.log("如果有任何问题, 请在 https://github.com/Soulter/QQChannelChatGPT 上提交issue或加群322154837", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
+ gu.log("请给 https://github.com/Soulter/QQChannelChatGPT 点个star!", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
+
+ # QQ频道
+ if 'qqbot' in cfg and cfg['qqbot']['enable']:
+ gu.log("- 启用QQ频道机器人 -", gu.LEVEL_INFO)
+ global qqchannel_bot, qqchan_loop
+ qqchannel_bot = QQChan()
+ qqchan_loop = asyncio.new_event_loop()
+ _global_object.platform_qqchan = qqchannel_bot
+ thread_inst = threading.Thread(target=run_qqchan_bot, args=(cfg, qqchan_loop, qqchannel_bot), daemon=False)
+ thread_inst.start()
+ # thread.join()
+
# 语言模型提供商
gu.log("--------加载语言模型--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
@@ -312,8 +367,6 @@ def initBot(cfg, prov):
nick_qq = tuple(nick_qq)
_global_object.nick = nick_qq
- thread_inst = None
-
gu.log("--------加载插件--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
# 加载插件
_command = Command(None, _global_object)
@@ -327,59 +380,6 @@ def initBot(cfg, prov):
llm_command_instance[NONE_LLM] = _command
chosen_provider = NONE_LLM
- gu.log("--------加载机器人平台--------", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
-
- admin_qq = cc.get('admin_qq', None)
- admin_qqchan = cc.get('admin_qqchan', None)
- if admin_qq == None:
- gu.log("未设置管理者QQ号(管理者才能使用update/plugin等指令)", gu.LEVEL_WARNING)
- admin_qq = input("请输入管理者QQ号(必须设置): ")
- gu.log("管理者QQ号设置为: " + admin_qq, gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
- cc.put('admin_qq', admin_qq)
- if admin_qqchan == None:
- gu.log("未设置管理者QQ频道用户号(管理者才能使用update/plugin等指令)", gu.LEVEL_WARNING)
- admin_qqchan = input("请输入管理者频道用户号(不是QQ号, 可以先回车跳过然后在频道发送指令!myid获取): ")
- if admin_qqchan == "":
- gu.log("跳过设置管理者频道用户号", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
- else:
- gu.log("管理者频道用户号设置为: " + admin_qqchan, gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
- cc.put('admin_qqchan', admin_qqchan)
-
- gu.log("管理者QQ: " + admin_qq, gu.LEVEL_INFO)
- gu.log("管理者频道用户号: " + admin_qqchan, gu.LEVEL_INFO)
- _global_object.admin_qq = admin_qq
- _global_object.admin_qqchan = admin_qqchan
-
- # GOCQ
- global gocq_bot
-
- if 'gocqbot' in cfg and cfg['gocqbot']['enable']:
- gu.log("- 启用QQ机器人 -", gu.LEVEL_INFO)
-
- global gocq_app, gocq_loop
- gocq_loop = asyncio.new_event_loop()
- gocq_bot = QQ(True, cc, gocq_loop)
- thread_inst = threading.Thread(target=run_gocq_bot, args=(gocq_loop, gocq_bot, gocq_app), daemon=False)
- thread_inst.start()
- else:
- gocq_bot = QQ(False)
-
- _global_object.platform_qq = gocq_bot
-
- gu.log("机器人部署教程: https://github.com/Soulter/QQChannelChatGPT/wiki/", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
- gu.log("如果有任何问题, 请在 https://github.com/Soulter/QQChannelChatGPT 上提交issue或加群322154837", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
- gu.log("请给 https://github.com/Soulter/QQChannelChatGPT 点个star!", gu.LEVEL_INFO, fg=gu.FG_COLORS['yellow'])
-
- # QQ频道
- if 'qqbot' in cfg and cfg['qqbot']['enable']:
- gu.log("- 启用QQ频道机器人 -", gu.LEVEL_INFO)
- global qqchannel_bot, qqchan_loop
- qqchannel_bot = QQChan()
- qqchan_loop = asyncio.new_event_loop()
- _global_object.platform_qqchan = qqchannel_bot
- thread_inst = threading.Thread(target=run_qqchan_bot, args=(cfg, qqchan_loop, qqchannel_bot), daemon=False)
- thread_inst.start()
- # thread.join()
if thread_inst == None:
input("[System-Error] 没有启用/成功启用任何机器人,程序退出")
@@ -606,7 +606,7 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
if session_id in gocq_bot.waiting and gocq_bot.waiting[session_id] == '':
gocq_bot.waiting[session_id] = qq_msg
return
- hit, command_result = llm_command_instance[chosen_provider].check_command(
+ hit, command_result = await llm_command_instance[chosen_provider].check_command(
qq_msg,
session_id,
role,
@@ -752,7 +752,6 @@ class botClient(botpy.Client):
# 收到频道消息
async def on_at_message_create(self, message: Message):
gu.log(str(message), gu.LEVEL_DEBUG, max_len=9999)
-
# 转换层
nakuru_guild_message = qqchannel_bot.gocq_compatible_receive(message)
gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999)
diff --git a/model/command/adapter/nonebot/command_arg.py b/model/command/adapter/nonebot/command_arg.py
new file mode 100644
index 000000000..3f6e8a305
--- /dev/null
+++ b/model/command/adapter/nonebot/command_arg.py
@@ -0,0 +1,2 @@
+class CommandArg:
+ pass
\ No newline at end of file
diff --git a/model/command/adapter/nonebot/common.py b/model/command/adapter/nonebot/common.py
new file mode 100644
index 000000000..ac7532768
--- /dev/null
+++ b/model/command/adapter/nonebot/common.py
@@ -0,0 +1,32 @@
+import sys
+from types import ModuleType
+import asyncio
+from pyppeteer import launch
+
+
+async def template_to_pic(template_path, template_name, templates, pages, wait, type, quality, device_scale_factor):
+ browser = await launch()
+ page = await browser.newPage()
+ await page.setViewport(pages["viewport"])
+ await page.goto(pages["base_url"])
+ await asyncio.sleep(wait)
+ await page.evaluate('''(templates) => {
+ // 在页面中执行 JavaScript 代码,将数据注入到模板中
+ // 这里的示例代码仅供参考,具体需要根据实际情况修改
+ document.getElementById('css').innerText = templates.css;
+ document.getElementById('data').innerText = JSON.stringify(templates.data);
+ document.getElementById('detail').innerText = templates.detail;
+ }''', templates)
+ screenshot = await page.screenshot({
+ 'type': type,
+ 'quality': quality,
+ 'deviceScaleFactor': device_scale_factor
+ })
+ await browser.close()
+ return screenshot
+
+def require(module_str: str):
+ module = ModuleType(module_str)
+ sys.modules[module_str] = module
+ if module_str == 'nonebot_plugin_htmlrender':
+ module.template_to_pic = template_to_pic
diff --git a/model/command/adapter/nonebot/driver.py b/model/command/adapter/nonebot/driver.py
new file mode 100644
index 000000000..a184fbcaf
--- /dev/null
+++ b/model/command/adapter/nonebot/driver.py
@@ -0,0 +1,11 @@
+class Driver:
+ def __init__(self) -> None:
+ self.config = {}
+
+ def on_startup(self, func):
+ pass
+ def on_bot_connect(self, func):
+ pass
+
+def get_driver():
+ return Driver()
\ No newline at end of file
diff --git a/model/command/adapter/onebot/bot.py b/model/command/adapter/onebot/bot.py
new file mode 100644
index 000000000..19be4bcb3
--- /dev/null
+++ b/model/command/adapter/onebot/bot.py
@@ -0,0 +1,2 @@
+class Bot:
+ pass
\ No newline at end of file
diff --git a/model/command/adapter/onebot/message.py b/model/command/adapter/onebot/message.py
new file mode 100644
index 000000000..574b0ce16
--- /dev/null
+++ b/model/command/adapter/onebot/message.py
@@ -0,0 +1,2 @@
+class Message:
+ pass
\ No newline at end of file
diff --git a/model/command/adapter/onebot/message_event.py b/model/command/adapter/onebot/message_event.py
new file mode 100644
index 000000000..11712e075
--- /dev/null
+++ b/model/command/adapter/onebot/message_event.py
@@ -0,0 +1,2 @@
+class MessageEvent:
+ pass
\ No newline at end of file
diff --git a/model/command/adapter/onebot/message_segment.py b/model/command/adapter/onebot/message_segment.py
new file mode 100644
index 000000000..2905a1757
--- /dev/null
+++ b/model/command/adapter/onebot/message_segment.py
@@ -0,0 +1,2 @@
+class MessageSegment:
+ pass
\ No newline at end of file
diff --git a/model/command/adapter/protocol_adapter.py b/model/command/adapter/protocol_adapter.py
new file mode 100644
index 000000000..1f0f6030a
--- /dev/null
+++ b/model/command/adapter/protocol_adapter.py
@@ -0,0 +1,137 @@
+import sys
+from types import ModuleType
+import asyncio
+from pyppeteer import launch
+
+from model.platform.qqchan import QQChan
+
+from .nonebot.driver import Driver, get_driver
+from .onebot.message import Message
+from .onebot.message_event import MessageEvent
+from .onebot.message_segment import MessageSegment
+from .nonebot.command_arg import CommandArg
+from .onebot.bot import Bot
+
+from nakuru import (
+ GuildMessage,
+ GroupMessage,
+ FriendMessage
+)
+
+from typing import Union
+
+NONEBOT = "nonebot"
+
+class UnifiedBotCompatibleLayer():
+ def __init__(self, platform_qq_sdk: QQChan) -> None:
+ # 初始化兼容层
+ self.plugins: dict[str, CommandOper] = {}
+ self.platform_qq_sdk = platform_qq_sdk
+ self._nonebot()
+ self.load_plugins()
+
+ async def check_commands(self, message: str, message_obj: Union[GroupMessage, FriendMessage, GuildMessage]):
+ for k in self.plugins:
+ if message.startswith(k):
+ if self.plugins[k].framework_name == NONEBOT:
+ await self._nonebot_plugins_oper(message, message_obj, k)
+
+ async def _nonebot_plugins_oper(self, message: str, message_obj: Union[GroupMessage, FriendMessage, GuildMessage], plugin_name: str = None):
+ # bad implementation
+ # 高并发场景下,下面的代码是不安全的
+ while self.plugins[plugin_name].message_obj is not None:
+ await asyncio.sleep(1)
+ self.plugins[plugin_name].message_obj = message_obj
+ bot, event, arg = self._nonebot_adapter(message_obj)
+ await self.plugins[plugin_name].exec(bot, event, arg) # wrapper
+
+ def load_plugins(self):
+ import nonebot_plugin_gspanel.nonebot_plugin_gspanel
+
+ def _nonebot(self):
+ # 模拟 nonebot 模块
+ nonebot_module = ModuleType('nonebot')
+ sys.modules['nonebot'] = nonebot_module
+
+ nonebot_log_module = ModuleType('nonebot.log')
+ sys.modules['nonebot.log'] = nonebot_log_module
+
+ nonebot_adapter_module = ModuleType('nonebot.adapters')
+ sys.modules['nonebot.adapters'] = nonebot_adapter_module
+
+ nonebot_params_module = ModuleType('nonebot.params')
+ sys.modules['nonebot.params'] = nonebot_params_module
+
+ nonebot_drivers_module = ModuleType('nonebot.drivers')
+ sys.modules['nonebot.drivers'] = nonebot_drivers_module
+
+ nonebot_plugin_module = ModuleType('nonebot.plugin')
+ sys.modules['nonebot.plugin'] = nonebot_plugin_module
+
+ nonebot_adapter_onebot_v11_module = ModuleType('nonebot.adapters.onebot.v11')
+ sys.modules['nonebot.adapters.onebot.v11'] = nonebot_adapter_onebot_v11_module
+
+ nonebot_adapter_onebot_v11_event_module = ModuleType('nonebot.adapters.onebot.v11.event')
+ sys.modules['nonebot.adapters.onebot.v11.event'] = nonebot_adapter_onebot_v11_event_module
+
+ nonebot_adapter_onebot_v11_message_module = ModuleType('nonebot.adapters.onebot.v11.message')
+ sys.modules['nonebot.adapters.onebot.v11.message'] = nonebot_adapter_onebot_v11_message_module
+
+ nonebot_log_module.logger = lambda: None
+ nonebot_adapter_module.Message = Message
+ nonebot_params_module.CommandArg = CommandArg
+ on_command = wrap_on_command(self)
+ nonebot_plugin_module.on_command = on_command
+ nonebot_adapter_onebot_v11_module.Bot = Bot
+ nonebot_adapter_onebot_v11_event_module.MessageEvent = MessageEvent
+ nonebot_adapter_onebot_v11_message_module.MessageSegment = MessageSegment
+ nonebot_module.get_driver = get_driver
+ nonebot_module.require = require
+ nonebot_drivers_module.Driver = Driver
+
+ def _nonebot_adapter(self, message_obj):
+ bot = Bot()
+ event = MessageEvent()
+ arg = CommandArg()
+ # tododssss
+ return bot, event, arg
+
+
+class BaseBot():
+ def __init__(self, framework_name) -> None:
+ self.framework_name = framework_name
+
+class CommandOper(BaseBot):
+ '''
+ CommandOper for NoneBot
+ '''
+ def __init__(self, name, aliases=None, priority=1, block=False, _ubcl: UnifiedBotCompatibleLayer = None) -> None:
+ super().__init__("nonebot")
+ self.name = name
+ self.aliases = aliases
+ self.priority = priority
+ self.block = block
+ self.exec = None
+ self._ubcl = _ubcl
+ self.message_obj: Union[GroupMessage, FriendMessage, GuildMessage] = None
+ _ubcl.plugins[name] = self
+
+ def handle(self):
+ def decorator(func):
+ async def wrapper(bot: Bot, event: MessageEvent, arg: Message = CommandArg(), *args, **kwargs):
+ # 你可以在这里添加自定义的处理逻辑
+ print(f"Command {self.name} is executed.")
+ await func(bot, event, arg, *args, **kwargs)
+ self.exec = wrapper
+ return wrapper
+ return decorator
+
+ async def finish(self, msg, at_sender = True):
+ if self.message_obj is not None:
+ self._ubcl.platform_qq_sdk.send(self.message_obj, msg)
+ self.message_obj = None
+
+def wrap_on_command(_ubcl: UnifiedBotCompatibleLayer):
+ def on_command(name, aliases=None, priority=1, block=False):
+ return CommandOper(name, aliases, priority, block, _ubcl = _ubcl)
+ return on_command
diff --git a/model/command/command.py b/model/command/command.py
index b23c4ec05..b25f774b1 100644
--- a/model/command/command.py
+++ b/model/command/command.py
@@ -26,21 +26,29 @@ from PIL import Image as PILImage
from cores.qqbot.global_object import GlobalObject, AstrMessageEvent
from pip._internal import main as pipmain
+from .adapter.protocol_adapter import UnifiedBotCompatibleLayer
+import asyncio
+
PLATFORM_QQCHAN = 'qqchan'
PLATFORM_GOCQ = 'gocq'
# 指令功能的基类,通用的(不区分语言模型)的指令就在这实现
class Command:
- def __init__(self, provider: Provider, global_object: GlobalObject = None):
+ def __init__(self, provider: Provider, global_object: GlobalObject = None, unified_bot_compatible_layer: UnifiedBotCompatibleLayer = None):
self.provider = provider
self.global_object = global_object
+ self.unified_bot_compatible_layer = unified_bot_compatible_layer
- def check_command(self,
+
+ async def check_command(self,
message,
session_id: str,
role,
platform,
message_obj):
+ # UBCL
+ await self.unified_bot_compatible_layer.check_commands(message, message_obj)
+
# 插件
cached_plugins = self.global_object.cached_plugins
ame = AstrMessageEvent(
@@ -70,10 +78,8 @@ class Command:
if self.command_start_with(message, "nick"):
return True, self.set_nick(message, platform, role)
-
if self.command_start_with(message, "plugin"):
return True, self.plugin_oper(message, role, cached_plugins, platform)
-
if self.command_start_with(message, "myid") or self.command_start_with(message, "!myid"):
return True, self.get_my_id(message_obj)
if self.command_start_with(message, "nconf") or self.command_start_with(message, "newconf"):
diff --git a/model/command/openai_official.py b/model/command/openai_official.py
index aa5247e6f..5974f12e1 100644
--- a/model/command/openai_official.py
+++ b/model/command/openai_official.py
@@ -5,6 +5,7 @@ from cores.qqbot.personality import personalities
from model.platform.qq import QQ
from util import general_utils as gu
from cores.qqbot.global_object import GlobalObject
+from .adapter.protocol_adapter import UnifiedBotCompatibleLayer
class CommandOpenAIOfficial(Command):
def __init__(self, provider: ProviderOpenAIOfficial, global_object: GlobalObject):
@@ -12,16 +13,17 @@ class CommandOpenAIOfficial(Command):
self.cached_plugins = {}
self.global_object = global_object
self.personality_str = ""
- super().__init__(provider, global_object)
+ self.unified_bot_compatible_layer = UnifiedBotCompatibleLayer(self.global_object.platform_qqchan)
+ super().__init__(provider, global_object, self.unified_bot_compatible_layer)
- def check_command(self,
+ async def check_command(self,
message: str,
session_id: str,
role: str,
platform: str,
message_obj):
self.platform = platform
- hit, res = super().check_command(
+ hit, res = await super().check_command(
message,
session_id,
role,
diff --git a/model/command/rev_chatgpt.py b/model/command/rev_chatgpt.py
index 6780930c2..3f8399e8b 100644
--- a/model/command/rev_chatgpt.py
+++ b/model/command/rev_chatgpt.py
@@ -12,14 +12,14 @@ class CommandRevChatGPT(Command):
self.personality_str = ""
super().__init__(provider, global_object)
- def check_command(self,
+ async def check_command(self,
message: str,
session_id: str,
role: str,
platform: str,
message_obj):
self.platform = platform
- hit, res = super().check_command(
+ hit, res = await super().check_command(
message,
session_id,
role,
diff --git a/model/platform/qqchan.py b/model/platform/qqchan.py
index 217598ae4..ad3974f82 100644
--- a/model/platform/qqchan.py
+++ b/model/platform/qqchan.py
@@ -188,4 +188,9 @@ class QQChan():
_n = NakuruGuildMessage()
_n.channel_id = channel_id
self.send_qq_msg(_n, message_chain)
-
\ No newline at end of file
+
+ def send(self, message: NakuruGuildMessage, res: list):
+ '''
+ 同 send_qq_msg。回复频道消息
+ '''
+ self.send_qq_msg(message, res)
\ No newline at end of file
From f4222e0923406d727918ea5ba227fca1ff851808 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Tue, 21 Nov 2023 22:37:35 +0800
Subject: [PATCH 02/47] bugfixes
---
model/command/adapter/protocol_adapter.py | 1 +
model/platform/qqchan.py | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/model/command/adapter/protocol_adapter.py b/model/command/adapter/protocol_adapter.py
index 1f0f6030a..3e182d4d7 100644
--- a/model/command/adapter/protocol_adapter.py
+++ b/model/command/adapter/protocol_adapter.py
@@ -11,6 +11,7 @@ from .onebot.message_event import MessageEvent
from .onebot.message_segment import MessageSegment
from .nonebot.command_arg import CommandArg
from .onebot.bot import Bot
+from .nonebot.common import require
from nakuru import (
GuildMessage,
diff --git a/model/platform/qqchan.py b/model/platform/qqchan.py
index ad3974f82..a9e46ad8a 100644
--- a/model/platform/qqchan.py
+++ b/model/platform/qqchan.py
@@ -76,7 +76,7 @@ class QQChan():
ngm.sub_type = "normal"
ngm.message_id = message.id
- ngm.guild_id = int(message.channel_id)
+ ngm.guild_id = int(message.guild_id)
ngm.channel_id = int(message.channel_id)
ngm.user_id = int(message.author.id)
msg = []
From cbe761fc3385981ae7738cf91eb3ac0decc29ea4 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Wed, 7 Aug 2024 00:49:00 +0800
Subject: [PATCH 03/47] Update README.md
---
README.md | 24 ++++++++++++++++++------
1 file changed, 18 insertions(+), 6 deletions(-)
diff --git a/README.md b/README.md
index 130e24633..fa692b958 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
-
+
@@ -21,28 +21,40 @@
🌍 支持的消息平台
- QQ 群、QQ 频道(OneBot、QQ 官方接口)
-- Telegram(由 [astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件支持)
-- WeChat(微信) (由 [astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件支持)
+- Telegram([astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件)
+- WeChat(微信) ([astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件)
-🌍 支持的大模型一览:
+🌍 支持的大模型/底座:
- OpenAI GPT、DallE 系列
- Claude(由[LLMs插件](https://github.com/Soulter/llms)支持)
- HuggingChat(由[LLMs插件](https://github.com/Soulter/llms)支持)
- Gemini(由[LLMs插件](https://github.com/Soulter/llms)支持)
+- Ollama
+- 几乎所有已知模型(可接入 [OneAPI](https://astrbot.soulter.top/docs/docs/adavanced/one-api))
🌍 机器人支持的能力一览:
- 大模型对话、人格、网页搜索
-- 可视化管理面板
+- 可视化仪表盘
- 同时处理多平台消息
- 精确到个人的会话隔离
- 插件支持
- 文本转图片回复(Markdown)
-## 🧩 插件支持
+## 🧩 插件
有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/center/docs/%E4%BD%BF%E7%94%A8/%E6%8F%92%E4%BB%B6)
+## ❤️ 贡献
+
+欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
+
+对于新功能的添加,请先通过 Issue 进行讨论。
+
+## 🔭 展望
+
+- [ ] 更多、更开放的 LLM Agent 能力
+
## ✨ Demo

From 933df5765411f7c16e71c64f1243776016c0b823 Mon Sep 17 00:00:00 2001
From: itgpt <136777961+itgpt-com@users.noreply.github.com>
Date: Thu, 8 Aug 2024 15:53:44 +0800
Subject: [PATCH 04/47] =?UTF-8?q?=E4=BC=98=E5=8C=96=20docker=20build?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.github/workflows/docker-image.yml | 45 +++++++++++++++++++++---------
1 file changed, 32 insertions(+), 13 deletions(-)
diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index 1828268bf..e374cfde8 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -4,20 +4,39 @@ on:
release:
types: [published]
workflow_dispatch:
+
jobs:
- publish-latest-docker-image:
+ publish-docker:
runs-on: ubuntu-latest
- name: Build and publish docker image
steps:
- - name: Checkout
- uses: actions/checkout@v2
- - name: Build image
- run: |
- git clone https://github.com/Soulter/AstrBot
- cd AstrBot
- docker build -t ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest .
- - name: Publish image
- run: |
- docker login -u ${{ secrets.DOCKER_HUB_USERNAME }} -p ${{ secrets.DOCKER_HUB_PASSWORD }}
- docker push ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
+ - name: 拉取源码
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 1
+
+ - name: 设置 QEMU
+ uses: docker/setup-qemu-action@v3
+
+ - name: 设置 Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: 登录到 DockerHub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKER_HUB_USERNAME }}
+ password: ${{ secrets.DOCKER_HUB_PASSWORD }}
+
+ - name: 构建和推送 Docker hub
+ uses: docker/build-push-action@v6
+ with:
+ context: .
+ platforms: linux/amd64,linux/arm64
+ push: true
+ tags: |
+ ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
+ ${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }}
+
+ - name: Post build notifications
+ run: echo "Docker image has been built and pushed successfully"
+
From 1d6ea2dbe6078d8a37c19e44dc0dbd607945d527 Mon Sep 17 00:00:00 2001
From: itgpt <136777961+itgpt-com@users.noreply.github.com>
Date: Thu, 8 Aug 2024 16:16:55 +0800
Subject: [PATCH 05/47] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=AB=AF=E5=8F=A3?=
=?UTF-8?q?=E8=BE=93=E5=87=BA?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
Dockerfile | 3 +++
1 file changed, 3 insertions(+)
diff --git a/Dockerfile b/Dockerfile
index 93c0a2914..9e027f258 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,4 +5,7 @@ COPY . /AstrBot/
RUN python -m pip install -r requirements.txt
+EXPOSE 6185
+EXPOSE 6186
+
CMD [ "python", "main.py" ]
From 0c679a0151f6a6d20cfeb1638289cefb4d063d51 Mon Sep 17 00:00:00 2001
From: itgpt <136777961+itgpt-com@users.noreply.github.com>
Date: Thu, 8 Aug 2024 16:21:30 +0800
Subject: [PATCH 06/47] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20.dockerignore=20?=
=?UTF-8?q?=E8=BF=87=E6=BB=A4=20docker=20cp=20=E4=B8=8D=E5=BF=85=E8=A6=81?=
=?UTF-8?q?=E6=96=87=E4=BB=B6=E3=80=82=E7=BC=A9=E5=B0=8F=E9=95=9C=E5=83=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.dockerignore | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
create mode 100644 .dockerignore
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 000000000..ddad8690a
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,17 @@
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+# github acions
+.github/
+.*ignore
+# User-specific stuff
+.idea/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+# Environments
+.env
+.venv
+env/
+venv*/
+ENV/
+.conda/
+README*.md
From 73dd4703b9fcaf7f41cd7ee355bee63428940a14 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Thu, 8 Aug 2024 22:15:05 +0800
Subject: [PATCH 07/47] Update .dockerignore
---
.dockerignore | 1 +
1 file changed, 1 insertion(+)
diff --git a/.dockerignore b/.dockerignore
index ddad8690a..f27cf068c 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -3,6 +3,7 @@
# github acions
.github/
.*ignore
+.git/
# User-specific stuff
.idea/
# Byte-compiled / optimized / DLL files
From 0f470cf96fc425c9b15037304bbd09abaadc0e63 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Fri, 9 Aug 2024 12:26:00 +0800
Subject: [PATCH 08/47] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index fa692b958..f40bc6e02 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@
## 🧩 插件
-有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/center/docs/%E4%BD%BF%E7%94%A8/%E6%8F%92%E4%BB%B6)
+有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/docs/get-started/plugin)
## ❤️ 贡献
From 9db43ac5e68f3a455f82d7dd1af7eac6b225d3da Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 10 Aug 2024 02:35:54 -0400
Subject: [PATCH 09/47] =?UTF-8?q?feat:=20=E6=B3=A8=E5=86=8C=E6=8C=87?=
=?UTF-8?q?=E4=BB=A4=E6=94=AF=E6=8C=81=E5=BF=BD=E7=95=A5=E6=8C=87=E4=BB=A4?=
=?UTF-8?q?=E5=89=8D=E7=BC=80=EF=BC=9B=E5=BF=AB=E6=8D=B7=E4=B8=BB=E5=8A=A8?=
=?UTF-8?q?=E5=9B=9E=E5=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/bootstrap.py | 1 +
model/command/manager.py | 21 +++++++++++++-
model/platform/__init__.py | 8 ++++++
model/platform/qq_aiocqhttp.py | 39 ++++++++++++++++++++++----
model/platform/qq_nakuru.py | 51 ++++++++++++++++++++++++++++++----
model/platform/qq_official.py | 10 +++++--
model/plugin/command.py | 5 ++--
type/command.py | 24 ++++++++--------
type/config.py | 2 +-
type/message_event.py | 26 ++++++++++-------
type/types.py | 35 +++++++++++++++++++++--
11 files changed, 181 insertions(+), 41 deletions(-)
diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py
index 0facf2037..af64e0dd6 100644
--- a/astrbot/bootstrap.py
+++ b/astrbot/bootstrap.py
@@ -78,6 +78,7 @@ class AstrBotBootstrap():
self.context.updator = self.updator
self.context.plugin_updator = self.plugin_manager.updator
self.context.message_handler = self.message_handler
+ self.context.command_manager = self.command_manager
# load plugins, plugins' commands.
self.load_plugins()
diff --git a/model/command/manager.py b/model/command/manager.py
index 3f7608c71..f6e90290c 100644
--- a/model/command/manager.py
+++ b/model/command/manager.py
@@ -21,6 +21,7 @@ class CommandMetadata():
plugin_metadata: PluginMetadata
handler: callable
use_regex: bool = False
+ ignore_prefix: bool = False
description: str = ""
class CommandManager():
@@ -35,6 +36,7 @@ class CommandManager():
priority: int,
handler: callable,
use_regex: bool = False,
+ ignore_prefix: bool = False,
plugin_metadata: PluginMetadata = None,
):
'''
@@ -53,6 +55,7 @@ class CommandManager():
plugin_metadata=plugin_metadata,
handler=handler,
use_regex=use_regex,
+ ignore_prefix=ignore_prefix,
description=description
)
if plugin_metadata:
@@ -75,9 +78,23 @@ class CommandManager():
priority=request.priority,
handler=request.handler,
use_regex=request.use_regex,
+ ignore_prefix=request.ignore_prefix,
plugin_metadata=plugin.metadata)
self.plugin_commands_waitlist = []
-
+
+ async def check_command_ignore_prefix(self, message_str: str) -> bool:
+ for _, command in self.commands:
+ command_metadata = self.commands_handler[command]
+ if command_metadata.ignore_prefix:
+ trig = False
+ if self.commands_handler[command].use_regex:
+ trig = self.command_parser.regex_match(message_str, command)
+ else:
+ trig = message_str.startswith(command)
+ if trig:
+ return True
+ return False
+
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
message_str = message_event.message_str
for _, command in self.commands:
@@ -89,6 +106,8 @@ class CommandManager():
if trig:
logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context)
+ if not command_result:
+ continue
if command_result.hit:
return command_result
diff --git a/model/platform/__init__.py b/model/platform/__init__.py
index 7eadf4270..1acac0fb5 100644
--- a/model/platform/__init__.py
+++ b/model/platform/__init__.py
@@ -3,6 +3,7 @@ from typing import Union, Any, List
from nakuru.entities.components import Plain, At, Image, BaseMessageComponent
from type.astrbot_message import AstrBotMessage
from type.command import CommandResult
+from type.astrbot_message import MessageType
class Platform():
@@ -30,6 +31,13 @@ class Platform():
发送消息(主动)
'''
pass
+
+ @abc.abstractmethod
+ async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
+ '''
+ 发送消息(主动)
+ '''
+ pass
def parse_message_outline(self, message: AstrBotMessage) -> str:
'''
diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py
index 0788a5b2e..9fc73ee52 100644
--- a/model/platform/qq_aiocqhttp.py
+++ b/model/platform/qq_aiocqhttp.py
@@ -103,12 +103,16 @@ class AIOCQHTTP(Platform):
await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool:
- # if message chain contains Plain components or At components which points to self_id, return True
+ # if message chain contains Plain components or
+ # At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE:
return True
for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True
+ # check commands which ignore prefix
+ if self.context.command_manager.check_command_ignore_prefix(message.message_str):
+ return True
# check nicks
if self.check_nick(message.message_str):
return True
@@ -129,14 +133,28 @@ class AIOCQHTTP(Platform):
else:
role = 'member'
+ # parse unified message origin
+ unified_msg_origin = None
+ assert isinstance(message.raw_message, Event)
+ if message.type == MessageType.GROUP_MESSAGE:
+ unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.raw_message.group_id}"
+ elif message.type == MessageType.FRIEND_MESSAGE:
+ unified_msg_origin = f"aiocqhttp:{message.type.value}:{message.sender.user_id}"
+
+ logger.debug(f"unified_msg_origin: {unified_msg_origin}")
+
# construct astrbot message event
- ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role)
+ ame = AstrMessageEvent.from_astrbot_message(message,
+ self.context,
+ "aiocqhttp",
+ message.session_id,
+ role, unified_msg_origin)
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
if not message_result: return
- await self.reply_msg(message, message_result.result_message)
+ await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback:
message_result.callback()
@@ -147,7 +165,8 @@ class AIOCQHTTP(Platform):
async def reply_msg(self,
message: AstrBotMessage,
- result_message: list):
+ result_message: list,
+ use_t2i: bool = None):
"""
回复用户唤醒机器人的消息。(被动回复)
"""
@@ -160,7 +179,7 @@ class AIOCQHTTP(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
- if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
+ if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
@@ -223,4 +242,12 @@ class AIOCQHTTP(Platform):
'''
- await self._reply(target, result_message.message_chain)
\ No newline at end of file
+ await self._reply(target, result_message.message_chain)
+
+ async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
+ if message_type == MessageType.GROUP_MESSAGE:
+ await self.send_msg({'group_id': int(target)}, result_message)
+ elif message_type == MessageType.FRIEND_MESSAGE:
+ await self.send_msg({'user_id': int(target)}, result_message)
+ else:
+ raise Exception("aiocqhttp: 无法识别的消息类型。")
\ No newline at end of file
diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py
index d4052094d..f3311a8b7 100644
--- a/model/platform/qq_nakuru.py
+++ b/model/platform/qq_nakuru.py
@@ -78,6 +78,9 @@ class QQGOCQ(Platform):
for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True
+ # check commands which ignore prefix
+ if self.context.command_manager.check_command_ignore_prefix(message.message_str):
+ return True
# check nicks
if self.check_nick(message.message_str):
return True
@@ -118,14 +121,34 @@ class QQGOCQ(Platform):
else:
role = 'member'
+ # parse unified message origin
+ unified_msg_origin = None
+ if message.type == MessageType.GROUP_MESSAGE:
+ assert isinstance(message.raw_message, GroupMessage)
+ unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.group_id}"
+ elif message.type == MessageType.FRIEND_MESSAGE:
+ assert isinstance(message.raw_message, FriendMessage)
+ unified_msg_origin = f"nakuru:{message.type.value}:{message.sender.user_id}"
+ elif message.type == MessageType.GUILD_MESSAGE:
+ assert isinstance(message.raw_message, GuildMessage)
+ unified_msg_origin = f"nakuru:{message.type.value}:{message.raw_message.channel_id}"
+
+ logger.debug(f"unified_msg_origin: {unified_msg_origin}")
+
+
# construct astrbot message event
- ame = AstrMessageEvent.from_astrbot_message(message, self.context, "gocq", session_id, role)
+ ame = AstrMessageEvent.from_astrbot_message(message,
+ self.context,
+ "nakuru",
+ session_id,
+ role,
+ unified_msg_origin)
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
if not message_result: return
- await self.reply_msg(message, message_result.result_message)
+ await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback:
message_result.callback()
@@ -135,7 +158,8 @@ class QQGOCQ(Platform):
async def reply_msg(self,
message: AstrBotMessage,
- result_message: List[BaseMessageComponent]):
+ result_message: List[BaseMessageComponent],
+ use_t2i: bool = None):
"""
回复用户唤醒机器人的消息。(被动回复)
"""
@@ -152,7 +176,7 @@ class QQGOCQ(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
- if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
+ if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
@@ -213,6 +237,23 @@ class QQGOCQ(Platform):
guild_id 不是频道号。
'''
await self._reply(target, result_message.message_chain)
+
+ async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
+ '''
+ 以主动的方式给用户、群或者频道发送一条消息。
+
+ `message_type` 为 MessageType 枚举类型。
+
+ - 要发给 QQ 下的某个用户,请使用 MessageType.FRIEND_MESSAGE;
+ - 要发给某个群聊,请使用 MessageType.GROUP_MESSAGE;
+ - 要发给某个频道,请使用 MessageType.GUILD_MESSAGE。
+ '''
+ if message_type == MessageType.FRIEND_MESSAGE:
+ await self.send_msg({"user_id": int(target)}, result_message)
+ elif message_type == MessageType.GROUP_MESSAGE:
+ await self.send_msg({"group_id": int(target)}, result_message)
+ elif message_type == MessageType.GUILD_MESSAGE:
+ await self.send_msg({"channel_id": int(target)}, result_message)
def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage:
abm = AstrBotMessage()
@@ -233,7 +274,7 @@ class QQGOCQ(Platform):
str(message.sender.user_id),
str(message.sender.nickname)
)
- abm.tag = "gocq"
+ abm.tag = "nakuru"
abm.message = message.message
return abm
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index 5ca3e301b..dc28453b8 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -222,7 +222,7 @@ class QQOfficial(Platform):
if not message_result:
return
- ret = await self.reply_msg(message, message_result.result_message)
+ ret = await self.reply_msg(message, message_result.result_message, message_result.use_t2i)
if message_result.callback:
message_result.callback()
@@ -234,7 +234,8 @@ class QQOfficial(Platform):
async def reply_msg(self,
message: AstrBotMessage,
- result_message: List[BaseMessageComponent]):
+ result_message: List[BaseMessageComponent],
+ use_t2i: bool = None):
'''
回复频道消息
'''
@@ -249,7 +250,7 @@ class QQOfficial(Platform):
msg_ref = None
rendered_images = []
- if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list):
+ if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(result_message)
if isinstance(result_message, list):
@@ -388,6 +389,9 @@ class QQOfficial(Platform):
if image_path:
payload['file_image'] = image_path
await self._reply(**payload)
+
+ async def send_msg_new(self, message_type: MessageType, target: str, result_message: CommandResult):
+ raise NotImplementedError("qqofficial 不支持此方法。")
def wait_for_message(self, channel_id: int) -> AstrBotMessage:
'''
diff --git a/model/plugin/command.py b/model/plugin/command.py
index 3321d52c7..1e4d8fab9 100644
--- a/model/plugin/command.py
+++ b/model/plugin/command.py
@@ -15,12 +15,13 @@ class CommandRegisterRequest():
handler: Callable
use_regex: bool = False
plugin_name: str = None
+ ignore_prefix: bool = False
class PluginCommandBridge():
def __init__(self, cached_plugins: RegisteredPlugins):
self.plugin_commands_waitlist: List[CommandRegisterRequest] = []
self.cached_plugins = cached_plugins
- def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False):
- self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name))
+ def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False, ignore_prefix=False):
+ self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name, ignore_prefix))
\ No newline at end of file
diff --git a/type/command.py b/type/command.py
index a8723063f..ac9ca0d50 100644
--- a/type/command.py
+++ b/type/command.py
@@ -2,7 +2,6 @@ from typing import Union, List, Callable
from dataclasses import dataclass
from nakuru.entities.components import Plain, Image
-
@dataclass
class CommandItem():
'''
@@ -19,12 +18,17 @@ class CommandResult():
用于在Command中返回多个值
'''
- def __init__(self, hit: bool = True, success: bool = True, message_chain: list = [], command_name: str = "unknown_command") -> None:
+ def __init__(self,
+ hit: bool = True,
+ success: bool = True,
+ message_chain: list = [],
+ command_name: str = "unknown_command",
+ use_t2i: bool = None) -> None:
self.hit = hit
self.success = success
self.message_chain = message_chain
self.command_name = command_name
- self.is_use_t2i = None # default
+ self.is_use_t2i = use_t2i
def message(self, message: str):
'''
@@ -63,14 +67,12 @@ class CommandResult():
self.message_chain = [Image.fromFileSystem(path), ]
return self
- # def use_t2i(self, use_t2i: bool):
- # '''
- # 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
-
- # CommandResult().use_t2i(False)
- # '''
- # self.is_use_t2i = use_t2i
- # return self
+ def use_t2i(self, use_t2i: bool):
+ '''
+ 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
+ '''
+ self.is_use_t2i = use_t2i
+ return self
def _result_tuple(self):
return (self.success, self.message_chain, self.command_name)
diff --git a/type/config.py b/type/config.py
index e13bc9262..6d7fa437c 100644
--- a/type/config.py
+++ b/type/config.py
@@ -1,4 +1,4 @@
-VERSION = '3.3.7'
+VERSION = '3.3.8'
DEFAULT_CONFIG = {
"qqbot": {
diff --git a/type/message_event.py b/type/message_event.py
index 222ac91e0..5ff8a7007 100644
--- a/type/message_event.py
+++ b/type/message_event.py
@@ -2,7 +2,14 @@ from typing import List, Union, Optional
from dataclasses import dataclass
from type.register import RegisteredPlatform
from type.types import Context
-from type.astrbot_message import AstrBotMessage
+from type.astrbot_message import AstrBotMessage, MessageType
+
+@dataclass
+class MessageResult():
+ result_message: Union[str, list]
+ is_command_call: Optional[bool] = False
+ use_t2i: Optional[bool] = None # None 为跟随用户设置
+ callback: Optional[callable] = None
class AstrMessageEvent():
@@ -12,7 +19,8 @@ class AstrMessageEvent():
platform: RegisteredPlatform,
role: str,
context: Context,
- session_id: str = None):
+ session_id: str = None,
+ unified_msg_origin: str = None):
'''
AstrBot 消息事件。
@@ -22,6 +30,7 @@ class AstrMessageEvent():
`role`: 角色,`admin` or `member`
`context`: 全局对象
`session_id`: 会话id
+ `unified_msg_origin`: 统一消息来源
'''
self.context = context
self.message_str = message_str
@@ -29,24 +38,21 @@ class AstrMessageEvent():
self.platform = platform
self.role = role
self.session_id = session_id
+ self.unified_msg_origin = unified_msg_origin
def from_astrbot_message(message: AstrBotMessage,
context: Context,
platform_name: str,
session_id: str,
- role: str = "member"):
+ role: str = "member",
+ unified_msg_origin: str = None):
ame = AstrMessageEvent(message.message_str,
message,
context.find_platform(platform_name),
role,
context,
- session_id)
+ session_id,
+ unified_msg_origin)
return ame
-@dataclass
-class MessageResult():
- result_message: Union[str, list]
- is_command_call: Optional[bool] = False
- use_t2i: Optional[bool] = None # None 为跟随用户设置
- callback: Optional[callable] = None
diff --git a/type/types.py b/type/types.py
index 195e2b1a6..0e3fee1cd 100644
--- a/type/types.py
+++ b/type/types.py
@@ -8,6 +8,8 @@ from util.t2i.renderer import TextToImageRenderer
from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator
+from type.command import CommandResult
+from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
@@ -40,6 +42,8 @@ class Context:
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
+
+ self.command_manager = None
# useless
self.reply_prefix = ""
@@ -50,7 +54,8 @@ class Context:
description: str,
priority: int,
handler: callable,
- use_regex: bool = False):
+ use_regex: bool = False,
+ ignore_prefix: bool = False):
'''
注册插件指令。
@@ -60,8 +65,19 @@ class Context:
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
@param use_regex: 是否使用正则表达式匹配指令名。
+ @param ignore_prefix: 是否忽略前缀。默认为 False。设置为 True 后,将不会检查用户设置的前缀。
+
+ .. Example::
+
+ ignore_prefix = False 时,用户输入 "/help" 时,会被识别为 "help" 指令。如果 ignore_prefix = True,则用户输入 "help" 也会被识别为 "help" 指令。
'''
- self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex)
+ self.plugin_command_bridge.register_command(plugin_name,
+ command_name,
+ description,
+ priority,
+ handler,
+ use_regex,
+ ignore_prefix)
def register_task(self, coro: Awaitable, task_name: str):
'''
@@ -87,3 +103,18 @@ class Context:
return platform
raise ValueError("couldn't find the platform you specified")
+
+ async def send_message(self, unified_msg_origin: str, message: CommandResult):
+ '''
+ 发送消息。
+
+ `unified_msg_origin`: 统一消息来源
+ `message`: 消息内容
+ '''
+ l = unified_msg_origin.split(":")
+ if len(l) != 3:
+ raise ValueError("Invalid unified_msg_origin")
+ platform_name, message_type, id = l
+ platform = self.find_platform(platform_name)
+ await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
+
\ No newline at end of file
From 1df83addfc42604685d2e5b3726a642610249b34 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Sat, 10 Aug 2024 14:59:00 +0800
Subject: [PATCH 10/47] update: add gcc
---
Dockerfile | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/Dockerfile b/Dockerfile
index 9e027f258..055d37bae 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -3,6 +3,15 @@ WORKDIR /AstrBot
COPY . /AstrBot/
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ gcc \
+ build-essential \
+ python3-dev \
+ libffi-dev \
+ libssl-dev \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
RUN python -m pip install -r requirements.txt
EXPOSE 6185
From f02731055ee30d042d05ef5e13794b798d0a34b7 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 10 Aug 2024 03:24:53 -0400
Subject: [PATCH 11/47] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=8F=92?=
=?UTF-8?q?=E4=BB=B6=E5=90=AF=E7=94=A8=E5=BF=BD=E7=95=A5=E5=89=8D=E7=BC=80?=
=?UTF-8?q?=E4=B9=8B=E5=90=8E=E5=8F=AF=E8=83=BD=E7=9A=84=E9=80=BB=E8=BE=91?=
=?UTF-8?q?=E5=86=B2=E7=AA=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/message/handler.py | 9 +++++++--
model/platform/qq_aiocqhttp.py | 17 ++++++++++-------
model/platform/qq_nakuru.py | 16 +++++++++-------
type/message_event.py | 11 ++++++++---
4 files changed, 34 insertions(+), 19 deletions(-)
diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py
index 12cd57419..49d8be400 100644
--- a/astrbot/message/handler.py
+++ b/astrbot/message/handler.py
@@ -134,8 +134,8 @@ class MessageHandler():
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
# TODO: this should be configurable
- if not message.message_str:
- return MessageResult("Hi~")
+ # if not message.message_str:
+ # return MessageResult("Hi~")
# check the rate limit
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
@@ -158,6 +158,11 @@ class MessageHandler():
use_t2i=cmd_res.is_use_t2i
)
+ # next is the LLM part
+
+ if message.only_command:
+ return
+
# check if the message is a llm-wake-up command
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py
index 9fc73ee52..89a699f69 100644
--- a/model/platform/qq_aiocqhttp.py
+++ b/model/platform/qq_aiocqhttp.py
@@ -106,23 +106,24 @@ class AIOCQHTTP(Platform):
# if message chain contains Plain components or
# At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE:
- return True
+ return True, "friend"
for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id:
- return True
+ return True, "at"
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
- return True
+ return True, "command"
# check nicks
if self.check_nick(message.message_str):
- return True
- return False
+ return True, "nick"
+ return False, "none"
async def handle_msg(self, message: AstrBotMessage):
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
- if not self.pre_check(message):
+ ok, reason = self.pre_check(message)
+ if not ok:
return
# 解析 role
@@ -148,7 +149,9 @@ class AIOCQHTTP(Platform):
self.context,
"aiocqhttp",
message.session_id,
- role, unified_msg_origin)
+ role,
+ unified_msg_origin,
+ reason == "command") # only_command
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py
index f3311a8b7..a40cb987c 100644
--- a/model/platform/qq_nakuru.py
+++ b/model/platform/qq_nakuru.py
@@ -74,17 +74,17 @@ class QQGOCQ(Platform):
def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE:
- return True
+ return True, "friend"
for comp in message.message:
if isinstance(comp, At) and str(comp.qq) == message.self_id:
- return True
+ return True, "at"
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
- return True
+ return True, "command"
# check nicks
if self.check_nick(message.message_str):
- return True
- return False
+ return True, "nick"
+ return False, "none"
def run(self):
coro = self.client._run()
@@ -98,7 +98,8 @@ class QQGOCQ(Platform):
(GroupMessage, FriendMessage, GuildMessage))
# 判断是否响应消息
- if not self.pre_check(message):
+ ok, reason = self.pre_check(message)
+ if not ok:
return
# 解析 session_id
@@ -142,7 +143,8 @@ class QQGOCQ(Platform):
"nakuru",
session_id,
role,
- unified_msg_origin)
+ unified_msg_origin,
+ reason == 'command') # only_command
# transfer control to message handler
message_result = await self.message_handler.handle(ame)
diff --git a/type/message_event.py b/type/message_event.py
index 5ff8a7007..dc0221897 100644
--- a/type/message_event.py
+++ b/type/message_event.py
@@ -20,7 +20,8 @@ class AstrMessageEvent():
role: str,
context: Context,
session_id: str = None,
- unified_msg_origin: str = None):
+ unified_msg_origin: str = None,
+ only_command: bool = False):
'''
AstrBot 消息事件。
@@ -31,6 +32,7 @@ class AstrMessageEvent():
`context`: 全局对象
`session_id`: 会话id
`unified_msg_origin`: 统一消息来源
+ `only_command`: 是否只处理指令,而不使用 LLM 回复
'''
self.context = context
self.message_str = message_str
@@ -39,13 +41,15 @@ class AstrMessageEvent():
self.role = role
self.session_id = session_id
self.unified_msg_origin = unified_msg_origin
+ self.only_command = only_command
def from_astrbot_message(message: AstrBotMessage,
context: Context,
platform_name: str,
session_id: str,
role: str = "member",
- unified_msg_origin: str = None):
+ unified_msg_origin: str = None,
+ only_command: bool = False):
ame = AstrMessageEvent(message.message_str,
message,
@@ -53,6 +57,7 @@ class AstrMessageEvent():
role,
context,
session_id,
- unified_msg_origin)
+ unified_msg_origin,
+ only_command=only_command)
return ame
From 95a8cc9498b94bb08c80bef79029a1f7260e7985 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 10 Aug 2024 04:13:24 -0400
Subject: [PATCH 12/47] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E9=83=A8?=
=?UTF-8?q?=E5=88=86=E5=AD=97=E6=AE=B5=E6=9C=AA=E6=9B=B4=E6=96=B0=E5=AF=BC?=
=?UTF-8?q?=E8=87=B4=E7=9A=84=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
model/platform/manager.py | 4 ++--
model/platform/qq_nakuru.py | 18 +++++++++++++++---
model/platform/qq_official.py | 2 +-
3 files changed, 18 insertions(+), 6 deletions(-)
diff --git a/model/platform/manager.py b/model/platform/manager.py
index 5ca217346..0f94765dc 100644
--- a/model/platform/manager.py
+++ b/model/platform/manager.py
@@ -58,7 +58,7 @@ class PlatformManager():
try:
qq_gocq = QQGOCQ(self.context, self.msg_handler)
self.context.platforms.append(RegisteredPlatform(
- platform_name="gocq", platform_instance=qq_gocq, origin="internal"))
+ platform_name="nakuru", platform_instance=qq_gocq, origin="internal"))
await qq_gocq.run()
except BaseException as e:
logger.error("启动 nakuru 适配器时出现错误: " + str(e))
@@ -81,7 +81,7 @@ class PlatformManager():
from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(self.context, self.msg_handler)
self.context.platforms.append(RegisteredPlatform(
- platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal"))
+ platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal"))
return qqchannel_bot.run()
except BaseException as e:
logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e))
diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py
index a40cb987c..bf1b2419b 100644
--- a/model/platform/qq_nakuru.py
+++ b/model/platform/qq_nakuru.py
@@ -195,14 +195,26 @@ class QQGOCQ(Platform):
message_chain = [Plain(text=message_chain), ]
is_dict = isinstance(source, dict)
- if source.type == "GuildMessage":
+
+ typ = None
+ if is_dict:
+ if "group_id" in source:
+ typ = "GroupMessage"
+ elif "user_id" in source:
+ typ = "FriendMessage"
+ elif "guild_id" in source:
+ typ = "GuildMessage"
+ else:
+ typ = source.type
+
+ if typ == "GuildMessage":
guild_id = source['guild_id'] if is_dict else source.guild_id
chan_id = source['channel_id'] if is_dict else source.channel_id
await self.client.sendGuildChannelMessage(guild_id, chan_id, message_chain)
- elif source.type == "FriendMessage":
+ elif typ == "FriendMessage":
user_id = source['user_id'] if is_dict else source.user_id
await self.client.sendFriendMessage(user_id, message_chain)
- elif source.type == "GroupMessage":
+ elif typ == "GroupMessage":
group_id = source['group_id'] if is_dict else source.group_id
# 过长时forward发送
plain_text_len = 0
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index dc28453b8..a19b6024a 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -112,7 +112,7 @@ class QQOfficial(Platform):
abm.timestamp = int(time.time())
abm.raw_message = message
abm.message_id = message.id
- abm.tag = "qqchan"
+ abm.tag = "qqofficial"
msg: List[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
From a876efb95f6c86fc6550e6166fb7f46473e88c0c Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 10 Aug 2024 04:35:07 -0400
Subject: [PATCH 13/47] =?UTF-8?q?fix:=20=E6=9B=B4=E6=96=B0=E5=90=8E?=
=?UTF-8?q?=E8=A6=86=E7=9B=96=E6=96=87=E4=BB=B6=E8=B7=AF=E5=BE=84=E9=94=99?=
=?UTF-8?q?=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
util/updator/astrbot_updator.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/util/updator/astrbot_updator.py b/util/updator/astrbot_updator.py
index eccdf2089..bc80a7bd5 100644
--- a/util/updator/astrbot_updator.py
+++ b/util/updator/astrbot_updator.py
@@ -9,7 +9,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotUpdator(RepoZipUpdator):
def __init__(self):
- self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
+ self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
def terminate_child_processes(self):
From 121c40f273bd3611254a45e08b80a8b00ba64a6a Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 11 Aug 2024 01:49:33 -0400
Subject: [PATCH 14/47] perf: raise error when badrequest
---
model/provider/openai_official.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py
index b5eefcd23..fc02a77d5 100644
--- a/model/provider/openai_official.py
+++ b/model/provider/openai_official.py
@@ -359,7 +359,7 @@ class ProviderOpenAIOfficial(Provider):
logger.warn(f"OpenAI 请求异常:{e}。")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
- retry += 1
+ raise e
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
From 266da0a9d89a525d9694135967b59bab5668c10b Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 11 Aug 2024 02:30:49 -0400
Subject: [PATCH 15/47] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E9=87=8D?=
=?UTF-8?q?=E5=90=AF=E6=97=B6=20aiocqhttp=20=E6=B2=A1=E6=9C=89=E6=AD=A3?=
=?UTF-8?q?=E5=B8=B8=E9=80=80=E5=87=BA=E5=AF=BC=E8=87=B4=E7=AB=AF=E5=8F=A3?=
=?UTF-8?q?=E5=8D=A0=E7=94=A8=E7=9A=84=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
astrbot/bootstrap.py | 3 +++
dashboard/server.py | 4 ++--
model/command/internal_handler.py | 4 ++--
model/platform/qq_aiocqhttp.py | 2 +-
type/types.py | 1 +
util/updator/astrbot_updator.py | 6 ++++--
6 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py
index af64e0dd6..225ab08a9 100644
--- a/astrbot/bootstrap.py
+++ b/astrbot/bootstrap.py
@@ -100,6 +100,9 @@ class AstrBotBootstrap():
try:
result = await task
return result
+ except asyncio.CancelledError:
+ logger.info(f"{task.get_name()} 任务已取消。")
+ return
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。")
diff --git a/dashboard/server.py b/dashboard/server.py
index e83a37161..60e880b51 100644
--- a/dashboard/server.py
+++ b/dashboard/server.py
@@ -311,7 +311,7 @@ class AstrBotDashBoard():
latest = False
try:
self.astrbot_updator.update(latest=latest, version=version)
- threading.Thread(target=self.astrbot_updator._reboot, args=(3, )).start()
+ threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
return Response(
status="success",
message="更新成功,机器人将在 3 秒内重启。",
@@ -374,7 +374,7 @@ class AstrBotDashBoard():
self.dashboard_data, self.context.config_helper.get_all())
# 重启
threading.Thread(target=self.astrbot_updator._reboot,
- args=(2, ), daemon=True).start()
+ args=(2, self.context), daemon=True).start()
except Exception as e:
raise e
diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py
index 377cd9df8..422c226f8 100644
--- a/model/command/internal_handler.py
+++ b/model/command/internal_handler.py
@@ -117,11 +117,11 @@ class InternalCommandHandler:
success=False,
message_chain="你没有权限使用该指令",
)
- context.updator._reboot(5)
+ context.updator._reboot(3, context)
return CommandResult(
hit=True,
success=True,
- message_chain="AstrBot 将在 5s 后重启。",
+ message_chain="AstrBot 将在 3s 后重启。",
)
def plugin(self, message: AstrMessageEvent, context: Context):
diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py
index 89a699f69..31e025341 100644
--- a/model/platform/qq_aiocqhttp.py
+++ b/model/platform/qq_aiocqhttp.py
@@ -99,7 +99,7 @@ class AIOCQHTTP(Platform):
return bot
async def shutdown_trigger_placeholder(self):
- while True:
+ while self.context.running:
await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool:
diff --git a/type/types.py b/type/types.py
index 0e3fee1cd..06aad685f 100644
--- a/type/types.py
+++ b/type/types.py
@@ -44,6 +44,7 @@ class Context:
self.ext_tasks: List[Task] = []
self.command_manager = None
+ self.running = True
# useless
self.reply_prefix = ""
diff --git a/util/updator/astrbot_updator.py b/util/updator/astrbot_updator.py
index bc80a7bd5..b6fbd9243 100644
--- a/util/updator/astrbot_updator.py
+++ b/util/updator/astrbot_updator.py
@@ -30,9 +30,11 @@ class AstrBotUpdator(RepoZipUpdator):
except psutil.NoSuchProcess:
pass
- def _reboot(self, delay: int = None):
- if delay: time.sleep(delay)
+ def _reboot(self, delay: int = None, context = None):
+ # if delay: time.sleep(delay)
py = sys.executable
+ context.running = False
+ time.sleep(3)
self.terminate_child_processes()
py = py.replace(" ", "\\ ")
try:
From 0633e7f25f3375e5dcf2904ea87301bf65ef1588 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 11 Aug 2024 03:55:31 -0400
Subject: [PATCH 16/47] perf: improve the effects of local function-calling
---
astrbot/message/handler.py | 2 +-
model/provider/openai_official.py | 2 +
util/agent/func_call.py | 210 +++++++-----------------------
util/agent/web_searcher.py | 36 ++---
4 files changed, 69 insertions(+), 181 deletions(-)
diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py
index 49d8be400..5d460182f 100644
--- a/astrbot/message/handler.py
+++ b/astrbot/message/handler.py
@@ -189,7 +189,7 @@ class MessageHandler():
try:
if web_search:
- llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider)
+ llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
else:
llm_result = await provider.text_chat(
prompt=msg_plain,
diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py
index fc02a77d5..31be2873b 100644
--- a/model/provider/openai_official.py
+++ b/model/provider/openai_official.py
@@ -370,6 +370,8 @@ class ProviderOpenAIOfficial(Provider):
await self.switch_to_next_key()
rate_limit_retry += 1
time.sleep(1)
+ except NotFoundError as e:
+ raise e
except Exception as e:
retry += 1
if retry >= 3:
diff --git a/util/agent/func_call.py b/util/agent/func_call.py
index ffacf242b..5bb602781 100644
--- a/util/agent/func_call.py
+++ b/util/agent/func_call.py
@@ -1,9 +1,7 @@
-
+from model.provider.provider import Provider
import json
-import util.general_utils as gu
-
import time
-
+import textwrap
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
@@ -22,14 +20,11 @@ class FuncNotFoundError(Exception):
class FuncCall():
- def __init__(self, provider) -> None:
+ def __init__(self, provider: Provider) -> None:
self.func_list = []
self.provider = provider
- def add_func(self, name: str = None, func_args: list = None, desc: str = None, func_obj=None) -> None:
- if name == None or func_args == None or desc == None or func_obj == None:
- raise FuncCallJsonFormatError(
- "name, func_args, desc must be provided.")
+ def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
params = {
"type": "object", # hardcore here
"properties": {}
@@ -47,7 +42,7 @@ class FuncCall():
}
self.func_list.append(self._func)
- def func_dump(self, intent: int = 2) -> str:
+ def func_dump(self) -> str:
_l = []
for f in self.func_list:
_l.append({
@@ -55,7 +50,7 @@ class FuncCall():
"parameters": f["parameters"],
"description": f["description"],
})
- return json.dumps(_l, indent=intent, ensur_ascii=False)
+ return json.dumps(_l, ensure_ascii=False)
def get_func(self) -> list:
_l = []
@@ -70,64 +65,36 @@ class FuncCall():
})
return _l
- def func_call(self, question, func_definition, is_task=False, tasks=None, taskindex=-1, is_summary=True, session_id=None):
+ async def func_call(self, question: str, func_definition: str, session_id: str=None):
- funccall_prompt = """
-我正实现function call功能,该功能旨在让你变成给定的问题到给定的函数的解析器(意味着你不是创造函数)。
-下面会给你提供可能用到的函数相关信息和一个问题,你需要将其转换成给定的函数调用。
-- 你的返回信息只含json,请严格仿照以下内容(不含注释),必须含有`res`,`func_call`字段:
-```
-{
- "res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
- "func_call": [ // 这是一个数组,里面包含了所有的函数调用,如果没有函数调用,那么这个数组是空数组。
- {
- "res": string // 如果没有找到对应的函数,那么你可以在这里正常输出内容。如果有,这里是空字符串。
- "name": str, // 函数的名字
- "args_type": {
- "arg1": str, // 函数的参数的类型
- "arg2": str,
- ...
- },
- "args": {
- "arg1": any, // 函数的参数
- "arg2": any,
- ...
- }
- },
- ... // 可能在这个问题中会有多个函数调用
- ],
-}
-```
-- 如果用户的要求较复杂,允许返回多个函数调用,但需保证这些函数调用的顺序正确。
-- 当问题没有提到给定的函数时,相当于提问方不打算使用function call功能,这时你可以在res中正常输出这个问题的回答(以AI的身份正常回答该问题,并将答案输出在res字段中,回答不要涉及到任何函数调用的内容,就只是正常讨论这个问题。)
+ prompt = textwrap.dedent(f"""
+ ROLE:
+ 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
-提供的函数是:
+ TOOLS:
+ 可用的函数列表:
-"""
+ {func_definition}
- prompt = f"{funccall_prompt}\n```\n{func_definition}\n```\n"
- prompt += f"""
-用户的提问是:
-```
-{question}
-```
-"""
+ LIMIT:
+ 1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
+ 2. 你的 Json 返回的格式如下:`[{{"name": "
", "args": }}, ...]`。参数根据上面提供的函数列表中的参数来填写。
+ 3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
+ 4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
- # if is_task:
- # # task_prompt = f"\n任务列表为{str(tasks)}\n你目前进行到了任务{str(taskindex)}, **你不需要重新进行已经进行过的任务, 不要生成已经进行过的**"
- # prompt += task_prompt
+ EXAMPLE:
+ 1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
- # provider.forget()
+ 用户的提问是:{question}
+ """)
_c = 0
while _c < 3:
try:
- res = self.provider.text_chat(prompt=prompt, session_id=session_id)
+ res = await self.provider.text_chat(prompt, session_id)
+ print(res)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
- gu.log("REVGPT func_call json result",
- bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
- print(res)
res = json.loads(res)
break
except Exception as e:
@@ -136,112 +103,25 @@ class FuncCall():
raise e
if "The message you submitted was too long" in str(e):
raise e
+
+ if 'res' in res and not res['res']:
+ return "", False
- invoke_func_res = ""
-
- if "func_call" in res and len(res["func_call"]) > 0:
- task_list = res["func_call"]
-
- invoke_func_res_list = []
-
- for res in task_list:
- # 说明有函数调用
- func_name = res["name"]
- # args_type = res["args_type"]
- args = res["args"]
- # 调用函数
- # func = eval(func_name)
- func_target = None
- for func in self.func_list:
- if func["name"] == func_name:
- func_target = func["func_obj"]
- break
- if func_target == None:
- raise FuncNotFoundError(
- f"Request function {func_name} not found.")
- t_res = str(func_target(**args))
- invoke_func_res += f"{func_name} 调用结果:\n```\n{t_res}\n```\n"
- invoke_func_res_list.append(invoke_func_res)
- gu.log(f"[FUNC| {func_name} invoked]",
- bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
- # print(str(t_res))
-
- if is_summary:
-
- # 生成返回结果
- after_prompt = """
-有以下内容:"""+invoke_func_res+"""
-请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
-用户的提问是:
-```""" + question + """```
-- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
-- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
-```json
-{
- "res": string, // 回答的内容
- "func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
-}
-```
-- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。"""
-
- _c = 0
- while _c < 5:
- try:
- res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
- # 截取```之间的内容
- gu.log(
- "DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
- print(res)
- gu.log(
- "DEBUG END", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
- if res.find('```') != -1:
- res = res[res.find('```json') +
- 7: res.rfind('```')]
- gu.log("REVGPT after_func_call json result",
- bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
- after_prompt_res = res
- after_prompt_res = json.loads(after_prompt_res)
- break
- except Exception as e:
- _c += 1
- if _c == 5:
- raise e
- if "The message you submitted was too long" in str(e):
- # 如果返回的内容太长了,那么就截取一部分
- time.sleep(3)
- invoke_func_res = invoke_func_res[:int(
- len(invoke_func_res) / 2)]
- after_prompt = """
-函数返回以下内容:"""+invoke_func_res+"""
-请以AI助手的身份结合返回的内容对用户提问做详细全面的回答。
-用户的提问是:
-```""" + question + """```
-- 在res字段中,不要输出函数的返回值,也不要针对返回值的字段进行分析,也不要输出用户的提问,而是理解这一段返回的结果,并以AI助手的身份回答问题,只需要输出回答的内容,不需要在回答的前面加上身份词。
-- 你的返回信息必须只能是json,且需严格遵循以下内容(不含注释):
-```json
-{
- "res": string, // 回答的内容
- "func_call_again": bool // 如果函数返回的结果有错误或者问题,可将其设置为true,否则为false
-}
-```
-- 如果func_call_again为true,res请你设为空值,否则请你填写回答的内容。"""
- else:
- raise e
-
- if "func_call_again" in after_prompt_res and after_prompt_res["func_call_again"]:
- # 如果需要重新调用函数
- # 重新调用函数
- gu.log("REVGPT func_call_again",
- bg=gu.BG_COLORS["purple"], fg=gu.FG_COLORS["white"])
- res = self.func_call(question, func_definition)
- return res, True
-
- gu.log("REVGPT func callback:",
- bg=gu.BG_COLORS["green"], fg=gu.FG_COLORS["white"])
- # print(after_prompt_res["res"])
- return after_prompt_res["res"], True
- else:
- return str(invoke_func_res_list), True
- else:
- # print(res["res"])
- return res["res"], False
+ tool_call_result = []
+ for tool in res:
+ # 说明有函数调用
+ func_name = tool["name"]
+ args = tool["args"]
+ # 调用函数
+ tool_callable = None
+ for func in self.func_list:
+ if func["name"] == func_name:
+ tool_callable = func["func_obj"]
+ break
+ if not tool_callable:
+ raise FuncNotFoundError(
+ f"Request function {func_name} not found.")
+ ret = await tool_callable(**args)
+ if ret:
+ tool_call_result.append(str(ret))
+ return tool_call_result, True
diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py
index 6badf8188..c8634c2b7 100644
--- a/util/agent/web_searcher.py
+++ b/util/agent/web_searcher.py
@@ -1,13 +1,13 @@
import traceback
import random
import json
-import asyncio
import aiohttp
import os
from readability import Document
from bs4 import BeautifulSoup
from openai.types.chat.chat_completion_message_tool_call import Function
+from openai._exceptions import *
from util.agent.func_call import FuncCall
from util.websearch.config import HEADERS, USER_AGENTS
from util.websearch.bing import Bing
@@ -100,9 +100,9 @@ async def fetch_website_content(url):
return ret
-async def web_search(prompt, provider: Provider, session_id, official_fc=False):
+async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False):
'''
- official_fc: 使用官方 function-calling
+ @param official_fc: 使用官方 function-calling
'''
new_func_call = FuncCall(provider)
@@ -127,9 +127,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
function_invoked_ret = ""
if official_fc:
# we use official function-calling
- result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
+ try:
+ result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
+ except BadRequestError as e:
+ # seems dont support function-calling
+ logger.error(f"error: {e}. Try to use local function-calling implementation")
+ return await web_search(prompt, provider, session_id, official_fc=False)
if isinstance(result, Function):
- logger.debug(f"web_searcher - function-calling: {result}")
+ logger.debug(f"function-calling: {result}")
func_obj = None
for i in new_func_call.func_list:
if i["name"] == result.name:
@@ -152,30 +157,31 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
args = {
'question': prompt,
'func_definition': new_func_call.func_dump(),
- 'is_task': False,
- 'is_summary': False,
}
- function_invoked_ret, has_func = await asyncio.to_thread(new_func_call.func_call, **args)
+ function_invoked_ret, has_func = await new_func_call.func_call(**args)
+
+ if not has_func:
+ return await provider.text_chat(prompt, session_id)
+
except BaseException as e:
- res = await provider.text_chat(prompt) + "\n(网页搜索失败, 此为默认回复)"
- return res
- has_func = True
+ logger.error(traceback.format_exc())
+ return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)"
if has_func:
- await provider.forget(session_id=session_id, )
+ await provider.forget(session_id=session_id)
summary_prompt = f"""
你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
-2. 简单地发表你对这个问题的简略看法。
+2. 简单地发表你对这个问题的看法。
# 例子
1. 从网上的信息来看,可以知道...我个人认为...你觉得呢?
2. 根据网上的最新信息,可以得知...我觉得...你怎么看?
# 限制
-1. 限制在 200 字以内;
+1. 限制在 200-300 字;
2. 请**直接输出总结**,不要输出多余的内容和提示语。
-
+
# 相关材料
{function_invoked_ret}"""
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
From 8d95e67b5ae303d9b7999749e620b726e28d7548 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Sun, 11 Aug 2024 17:13:49 +0800
Subject: [PATCH 17/47] Update README.md
---
README.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/README.md b/README.md
index f40bc6e02..377037db1 100644
--- a/README.md
+++ b/README.md
@@ -45,6 +45,10 @@
有关插件的使用和列表请移步:[AstrBot 文档 - 插件](https://astrbot.soulter.top/docs/get-started/plugin)
+## 云部署
+
+[](https://repl.it/github/Soulter/AstrBot)
+
## ❤️ 贡献
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
From 141c91301f0c6f2cd2642dae8756aebaa9118b92 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 11 Aug 2024 12:01:40 -0400
Subject: [PATCH 18/47] perf: Improve sleep time handling in QQOfficial and
ProviderOpenAIOfficial
---
model/platform/qq_official.py | 2 +-
model/provider/openai_official.py | 10 ++++------
util/general_utils.py | 30 ------------------------------
3 files changed, 5 insertions(+), 37 deletions(-)
delete mode 100644 util/general_utils.py
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index a19b6024a..6911cb367 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -408,4 +408,4 @@ class QQOfficial(Platform):
cnt += 1
if cnt > 300:
raise Exception("等待消息超时。")
- time.sleep(1)()
+ time.sleep(1)
diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py
index 31be2873b..4914b9a8c 100644
--- a/model/provider/openai_official.py
+++ b/model/provider/openai_official.py
@@ -1,5 +1,5 @@
import os
-import sys
+import asyncio
import json
import time
import tiktoken
@@ -14,8 +14,6 @@ from openai._exceptions import *
from astrbot.persist.helper import dbConn
from model.provider.provider import Provider
-from util import general_utils as gu
-from util.cmd_config import CmdConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
@@ -369,7 +367,7 @@ class ProviderOpenAIOfficial(Provider):
logger.error(f"OpenAI API Key {self.chosen_api_key} 达到请求速率限制或者官方服务器当前超载。详细原因:{e}")
await self.switch_to_next_key()
rate_limit_retry += 1
- time.sleep(1)
+ await asyncio.sleep(1)
except NotFoundError as e:
raise e
except Exception as e:
@@ -383,7 +381,7 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(traceback.format_exc())
logger.warning(f"OpenAI 请求失败:{e}。重试第 {retry} 次。")
- time.sleep(1)
+ await asyncio.sleep(1)
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
@@ -454,7 +452,7 @@ class ProviderOpenAIOfficial(Provider):
logger.error(traceback.format_exc())
raise Exception(f"OpenAI 图片生成请求失败:{e}。重试次数已达到上限。")
logger.warning(f"OpenAI 图片生成请求失败:{e}。重试第 {retry} 次。")
- time.sleep(1)
+ await asyncio.sleep(1)
async def forget(self, session_id=None, keep_system_prompt: bool=False) -> bool:
if session_id is None: return False
diff --git a/util/general_utils.py b/util/general_utils.py
deleted file mode 100644
index 270faf418..000000000
--- a/util/general_utils.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import time
-import asyncio
-import requests
-import json
-import sys
-import psutil
-
-from type.types import Context
-from SparkleLogging.utils.core import LogManager
-from logging import Logger
-
-logger: Logger = LogManager.GetLogger(log_name='astrbot')
-
-def run_monitor(global_object: Context):
- '''
- 监测机器性能
- - Bot 内存使用量
- - CPU 占用率
- '''
- start_time = time.time()
- while True:
- stat = global_object.dashboard_data.stats
- # 程序占用的内存大小
- mem = psutil.Process().memory_info().rss / 1024 / 1024 # MB
- stat['sys_perf'] = {
- 'memory': mem,
- 'cpu': psutil.cpu_percent()
- }
- stat['sys_start_time'] = start_time
- time.sleep(30)
From f8949ebead0eca5f3e79335e784e8edb59f2fefb Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 11 Aug 2024 23:23:52 -0400
Subject: [PATCH 19/47] perf: reboot after installing plugin
---
dashboard/server.py | 10 ++++++----
main.py | 2 +-
model/plugin/manager.py | 44 ++++++++++++++++++++++++++++++++---------
3 files changed, 42 insertions(+), 14 deletions(-)
diff --git a/dashboard/server.py b/dashboard/server.py
index 60e880b51..a2fdc2564 100644
--- a/dashboard/server.py
+++ b/dashboard/server.py
@@ -192,10 +192,11 @@ class AstrBotDashBoard():
try:
logger.info(f"正在安装插件 {repo_url}")
self.plugin_manager.install_plugin(repo_url)
- logger.info(f"安装插件 {repo_url} 成功")
+ threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
+ logger.info(f"安装插件 {repo_url} 成功,2秒后重启")
return Response(
status="success",
- message="安装成功~",
+ message="安装成功,机器人将在 2 秒内重启。",
data=None
).__dict__
except Exception as e:
@@ -258,10 +259,11 @@ class AstrBotDashBoard():
try:
logger.info(f"正在更新插件 {plugin_name}")
self.plugin_manager.update_plugin(plugin_name)
- logger.info(f"更新插件 {plugin_name} 成功")
+ threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
+ logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response(
status="success",
- message="更新成功~",
+ message="更新成功,机器人将在 2 秒内重启。",
data=None
).__dict__
except Exception as e:
diff --git a/main.py b/main.py
index 9803ffe0d..3eead7548 100644
--- a/main.py
+++ b/main.py
@@ -53,7 +53,7 @@ if __name__ == "__main__":
check_env()
logger = LogManager.GetLogger(
- log_name='astrbot',
+ log_name='astrbot',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
diff --git a/model/plugin/manager.py b/model/plugin/manager.py
index 35fcb90e7..43e9cc30f 100644
--- a/model/plugin/manager.py
+++ b/model/plugin/manager.py
@@ -5,6 +5,7 @@ import traceback
import uuid
import shutil
import yaml
+import subprocess
from util.updator.plugin_updator import PluginUpdator
from util.io import remove_dir, download_file
@@ -84,8 +85,28 @@ class PluginManager():
def update_plugin_dept(self, path):
mirror = "https://mirrors.aliyun.com/pypi/simple/"
py = sys.executable
- os.system(f"{py} -m pip install -r {path} -i {mirror} --quiet")
-
+ # os.system(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com")
+
+ process = subprocess.Popen(f"{py} -m pip install -r {path} -i {mirror} --break-system-package --trusted-host mirrors.aliyun.com",
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True)
+
+ while True:
+ output = process.stdout.readline()
+ if output == '' and process.poll() is not None:
+ break
+ if output:
+ output = output.strip()
+ if output.startswith("Requirement already satisfied"):
+ continue
+ if output.startswith("Using cached"):
+ continue
+ if output.startswith("Looking in indexes"):
+ continue
+ logger.info(output)
+
+ rc = process.poll()
+
+
def install_plugin(self, repo_url: str):
ppath = self.plugin_store_path
@@ -95,10 +116,13 @@ class PluginManager():
plugin_path = self.updator.update(repo_url)
with open(os.path.join(plugin_path, "REPO"), "w", encoding='utf-8') as f:
f.write(repo_url)
+
+ self.check_plugin_dept_update()
- ok, err = self.plugin_reload()
- if not ok:
- raise Exception(err)
+ return plugin_path
+ # ok, err = self.plugin_reload()
+ # if not ok:
+ # raise Exception(err)
def download_from_repo_url(self, target_path: str, repo_url: str):
repo_namespace = repo_url.split("/")[-2:]
@@ -158,7 +182,7 @@ class PluginManager():
logger.info(f"正在加载插件 {root_dir_name} ...")
- # self.check_plugin_dept_update(cached_plugins, root_dir_name)
+ self.check_plugin_dept_update(target_plugin=root_dir_name)
module = __import__("addons.plugins." +
root_dir_name + "." + p, fromlist=[p])
@@ -227,10 +251,12 @@ class PluginManager():
# remove the temp dir
remove_dir(temp_dir)
+
+ self.check_plugin_dept_update()
- ok, err = self.plugin_reload()
- if not ok:
- raise Exception(err)
+ # ok, err = self.plugin_reload()
+ # if not ok:
+ # raise Exception(err)
def load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> PluginMetadata:
metadata = None
From a578edf137211105980f4d4b9856bb49391e4b81 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Mon, 12 Aug 2024 02:50:31 -0400
Subject: [PATCH 20/47] fix: metrics perf: aiocqhttp image url
---
model/platform/__init__.py | 9 ++++++---
model/platform/qq_aiocqhttp.py | 7 +++++--
model/platform/qq_nakuru.py | 1 +
model/platform/qq_official.py | 2 +-
util/metrics.py | 3 +++
5 files changed, 16 insertions(+), 6 deletions(-)
diff --git a/model/platform/__init__.py b/model/platform/__init__.py
index 1acac0fb5..37531ce67 100644
--- a/model/platform/__init__.py
+++ b/model/platform/__init__.py
@@ -7,8 +7,9 @@ from type.astrbot_message import MessageType
class Platform():
- def __init__(self) -> None:
- pass
+ def __init__(self, platform_name: str, context) -> None:
+ self.PLATFORM_NAME = platform_name
+ self.context = context
@abc.abstractmethod
async def handle_msg(self, message: AstrBotMessage):
@@ -79,4 +80,6 @@ class Platform():
else:
rendered_images.append(Image.fromFileSystem(p))
return rendered_images
-
\ No newline at end of file
+
+ async def record_metrics(self):
+ self.context.metrics_uploader.increment_platform_stat(self.PLATFORM_NAME)
\ No newline at end of file
diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py
index 31e025341..c70e6bef3 100644
--- a/model/platform/qq_aiocqhttp.py
+++ b/model/platform/qq_aiocqhttp.py
@@ -18,6 +18,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AIOCQHTTP(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
+ super().__init__("aiocqhttp", context)
self.message_handler = message_handler
self.waiting = {}
self.context = context
@@ -67,7 +68,9 @@ class AIOCQHTTP(Platform):
message_str += m['data']['text'].strip()
abm.message.append(a)
if t == 'image':
- a = Image(file=m['data']['file'])
+ file = m['data']['file'] if 'file' in m['data'] else None
+ url = m['data']['url'] if 'url' in m['data'] else None
+ a = Image(file=file, url=url)
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -195,9 +198,9 @@ class AIOCQHTTP(Platform):
await self._reply(message, res)
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
+ await self.record_metrics()
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
-
ret = []
image_idx = []
for idx, segment in enumerate(message_chain):
diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py
index bf1b2419b..8b3856cd9 100644
--- a/model/platform/qq_nakuru.py
+++ b/model/platform/qq_nakuru.py
@@ -30,6 +30,7 @@ class FakeSource:
class QQGOCQ(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
+ super().__init__("nakuru", context)
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index 6911cb367..147fd5b47 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -53,7 +53,7 @@ class botClient(Client):
class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
- super().__init__()
+ super().__init__("qqofficial", context)
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
diff --git a/util/metrics.py b/util/metrics.py
index 172dd7ad4..59ee06633 100644
--- a/util/metrics.py
+++ b/util/metrics.py
@@ -65,6 +65,9 @@ class MetricUploader():
except BaseException as e:
pass
await asyncio.sleep(30*60)
+
+ def increment_platform_stat(self, platform_name: str):
+ self.platform_stats[platform_name] = self.platform_stats.get(platform_name, 0) + 1
def clear(self):
self.platform_stats.clear()
From 33ec92258dc506f2f8ae7ac164b34818cfe6bf1f Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Tue, 13 Aug 2024 15:05:16 +0800
Subject: [PATCH 21/47] Update config.py
---
type/config.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/type/config.py b/type/config.py
index 6d7fa437c..c5a2a60f4 100644
--- a/type/config.py
+++ b/type/config.py
@@ -1,4 +1,4 @@
-VERSION = '3.3.8'
+VERSION = '3.3.9'
DEFAULT_CONFIG = {
"qqbot": {
@@ -72,4 +72,4 @@ DEFAULT_CONFIG = {
"ws_reverse_host": "",
"ws_reverse_port": 0,
}
-}
\ No newline at end of file
+}
From 12216853c574b2fb50d85263fd8330acb643385c Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 11:20:36 +0800
Subject: [PATCH 22/47] chore: issue and pr template
---
.github/ISSUE_TEMPLATE/bug-report.yml | 82 ++++++++++++++++++++++
.github/ISSUE_TEMPLATE/feature-request.yml | 42 +++++++++++
.github/PULL_REQUEST_TEMPLATE.md | 10 +++
3 files changed, 134 insertions(+)
create mode 100644 .github/ISSUE_TEMPLATE/bug-report.yml
create mode 100644 .github/ISSUE_TEMPLATE/feature-request.yml
create mode 100644 .github/PULL_REQUEST_TEMPLATE.md
diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml
new file mode 100644
index 000000000..e7a5263c2
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug-report.yml
@@ -0,0 +1,82 @@
+name: '🐛 报告 Bug'
+title: '[Bug]'
+description: 提交报告帮助我们改进。
+labels: [ 'bug' ]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ 感谢您抽出时间报告问题!请准确解释您的问题。如果可能,请提供一个可复现的片段(这有助于更快地解决问题)。
+ - type: textarea
+ attributes:
+ label: 发生了什么
+ description: 描述你遇到的异常
+ placeholder: >
+ 一个清晰且具体的描述这个异常是什么。
+ validations:
+ required: true
+
+ - type: textarea
+ attributes:
+ label: 如何复现?
+ description: >
+ 复现该问题的步骤
+ placeholder: >
+ 如: 1. 打开 '...'
+ validations:
+ required: true
+
+ - type: textarea
+ attributes:
+ label: AstrBot 版本与部署方式
+ description: >
+ 请提供您的 AstrBot 版本和部署方式。
+ placeholder: >
+ 如: 3.1.8 Docker, 3.1.7 Windows启动器
+ validations:
+ required: true
+
+ - type: dropdown
+ attributes:
+ label: 操作系统
+ description: |
+ 你在哪个操作系统上遇到了这个问题?
+ multiple: false
+ options:
+ - 'Windows'
+ - 'macOS'
+ - 'Linux'
+ - 'Other'
+ - 'Not sure'
+ validations:
+ required: true
+
+ - type: textarea
+ attributes:
+ label: 额外信息
+ description: >
+ 任何额外信息,如报错日志、截图等。
+ placeholder: >
+ 请提供完整的报错日志或截图。
+ validations:
+ required: true
+
+ - type: checkboxes
+ attributes:
+ label: 你愿意提交 PR 吗?
+ description: >
+ 这绝对不是必需的,但我们很乐意在贡献过程中为您提供指导特别是如果你已经很好地理解了如何实现修复。
+ options:
+ - label: 是的,我愿意提交 PR!
+
+ - type: checkboxes
+ attributes:
+ label: Code of Conduct
+ options:
+ - label: >
+ 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
+ required: true
+
+ - type: markdown
+ attributes:
+ value: "感谢您填写我们的表单!"
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml
new file mode 100644
index 000000000..484959318
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature-request.yml
@@ -0,0 +1,42 @@
+
+name: '🎉 功能建议'
+title: "[Feature]"
+description: 提交建议帮助我们改进。
+labels: [ "enhancement" ]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ 感谢您抽出时间提出新功能建议,请准确解释您的想法。
+
+ - type: textarea
+ attributes:
+ label: 描述
+ description: 简短描述您的功能建议。
+
+ - type: textarea
+ attributes:
+ label: 使用场景
+ description: 你想要发生什么?
+ placeholder: >
+ 一个清晰且具体的描述这个功能的使用场景。
+
+ - type: checkboxes
+ attributes:
+ label: 你愿意提交PR吗?
+ description: >
+ 这不是必须的,但我们欢迎您的贡献。
+ options:
+ - label: 是的, 我愿意提交PR!
+
+ - type: checkboxes
+ attributes:
+ label: Code of Conduct
+ options:
+ - label: >
+ 我已阅读并同意遵守该项目的 [行为准则](https://docs.github.com/zh/site-policy/github-terms/github-community-code-of-conduct)。
+ required: true
+
+ - type: markdown
+ attributes:
+ value: "感谢您填写我们的表单!"
\ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 000000000..da603d465
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,10 @@
+
+修复了 #XYZ
+
+### Motivation
+
+
+
+### Modifications
+
+
From 970fe020271bda3743f2f6ca5e77f79aa6d88ed4 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 14:30:35 +0800
Subject: [PATCH 23/47] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8DQQ=E5=AE=98?=
=?UTF-8?q?=E6=96=B9=E6=9C=BA=E5=99=A8=E4=BA=BAAPI=E8=81=8A=E5=A4=A9?=
=?UTF-8?q?=E6=97=B6=E4=B8=8D=E8=83=BD=E6=89=BE=E5=88=B0=E5=B9=B3=E5=8F=B0?=
=?UTF-8?q?=E7=9A=84=E9=97=AE=E9=A2=98=20#189?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
model/platform/qq_official.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index 147fd5b47..aa770d5f8 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -216,7 +216,7 @@ class QQOfficial(Platform):
role = 'member'
# construct astrbot message event
- ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqchan", session_id, role)
+ ame = AstrMessageEvent.from_astrbot_message(message, self.context, "qqofficial", session_id, role)
message_result = await self.message_handler.handle(ame)
if not message_result:
From e8679f89843713b6b99f97debdea8fa8239298f5 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Sat, 17 Aug 2024 14:34:02 +0800
Subject: [PATCH 24/47] Create codeql.yml
---
.github/workflows/codeql.yml | 93 ++++++++++++++++++++++++++++++++++++
1 file changed, 93 insertions(+)
create mode 100644 .github/workflows/codeql.yml
diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml
new file mode 100644
index 000000000..8503bb715
--- /dev/null
+++ b/.github/workflows/codeql.yml
@@ -0,0 +1,93 @@
+# For most projects, this workflow file will not need changing; you simply need
+# to commit it to your repository.
+#
+# You may wish to alter this file to override the set of languages analyzed,
+# or to provide custom queries or build logic.
+#
+# ******** NOTE ********
+# We have attempted to detect the languages in your repository. Please check
+# the `language` matrix defined below to confirm you have the correct set of
+# supported CodeQL languages.
+#
+name: "CodeQL"
+
+on:
+ push:
+ branches: [ "master" ]
+ pull_request:
+ branches: [ "master" ]
+ schedule:
+ - cron: '21 15 * * 5'
+
+jobs:
+ analyze:
+ name: Analyze (${{ matrix.language }})
+ # Runner size impacts CodeQL analysis time. To learn more, please see:
+ # - https://gh.io/recommended-hardware-resources-for-running-codeql
+ # - https://gh.io/supported-runners-and-hardware-resources
+ # - https://gh.io/using-larger-runners (GitHub.com only)
+ # Consider using larger runners or machines with greater resources for possible analysis time improvements.
+ runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
+ timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }}
+ permissions:
+ # required for all workflows
+ security-events: write
+
+ # required to fetch internal or private CodeQL packs
+ packages: read
+
+ # only required for workflows in private repositories
+ actions: read
+ contents: read
+
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - language: python
+ build-mode: none
+ # CodeQL supports the following values keywords for 'language': 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
+ # Use `c-cpp` to analyze code written in C, C++ or both
+ # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
+ # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
+ # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
+ # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
+ # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
+ # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ # Initializes the CodeQL tools for scanning.
+ - name: Initialize CodeQL
+ uses: github/codeql-action/init@v3
+ with:
+ languages: ${{ matrix.language }}
+ build-mode: ${{ matrix.build-mode }}
+ # If you wish to specify custom queries, you can do so here or in a config file.
+ # By default, queries listed here will override any specified in a config file.
+ # Prefix the list here with "+" to use these queries and those in the config file.
+
+ # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
+ # queries: security-extended,security-and-quality
+
+ # If the analyze step fails for one of the languages you are analyzing with
+ # "We were unable to automatically build your code", modify the matrix above
+ # to set the build mode to "manual" for that language. Then modify this step
+ # to build your code.
+ # ℹ️ Command-line programs to run using the OS shell.
+ # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
+ - if: matrix.build-mode == 'manual'
+ shell: bash
+ run: |
+ echo 'If you are using a "manual" build mode for one or more of the' \
+ 'languages you are analyzing, replace this with the commands to build' \
+ 'your code, for example:'
+ echo ' make bootstrap'
+ echo ' make release'
+ exit 1
+
+ - name: Perform CodeQL Analysis
+ uses: github/codeql-action/analyze@v3
+ with:
+ category: "/language:${{matrix.language}}"
From e21736b470381473189ac8bc58a17a6bdb610537 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 14:54:11 +0800
Subject: [PATCH 25/47] perf: remove message reply when rate limit occur
---
astrbot/message/handler.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py
index 5d460182f..7b455b4d9 100644
--- a/astrbot/message/handler.py
+++ b/astrbot/message/handler.py
@@ -138,9 +138,11 @@ class MessageHandler():
# return MessageResult("Hi~")
# check the rate limit
- if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
- return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
-
+ if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
+ # return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
+ logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。")
+ return
+
# remove the nick prefix
for nick in self.nicks:
if msg_plain.startswith(nick):
From 8a5877291125110945638ba27d1fa40ad078bdd8 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 14:58:43 +0800
Subject: [PATCH 26/47] perf: fill the missing metric record
---
model/platform/qq_nakuru.py | 1 +
model/platform/qq_official.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py
index 8b3856cd9..e904d3403 100644
--- a/model/platform/qq_nakuru.py
+++ b/model/platform/qq_nakuru.py
@@ -192,6 +192,7 @@ class QQGOCQ(Platform):
await self._reply(source, res)
async def _reply(self, source, message_chain: List[BaseMessageComponent]):
+ await self.record_metrics()
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index aa770d5f8..cc9762c09 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -321,6 +321,7 @@ class QQOfficial(Platform):
return await self._reply(**data)
async def _reply(self, **kwargs):
+ await self.record_metrics()
if 'group_openid' in kwargs or 'openid' in kwargs:
# QQ群组消息
if 'file_image' in kwargs and kwargs['file_image']:
From 107214ac537bd72b83ae2b13906cc38ec803c076 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 15:01:55 +0800
Subject: [PATCH 27/47] fix: Handle errors in AstrBotBootstrap gracefully
---
astrbot/bootstrap.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py
index 225ab08a9..36c0f010f 100644
--- a/astrbot/bootstrap.py
+++ b/astrbot/bootstrap.py
@@ -105,8 +105,8 @@ class AstrBotBootstrap():
return
except Exception as e:
logger.error(traceback.format_exc())
- logger.error(f"{task.get_name()} 任务发生错误,将在 5 秒后重试。")
- await asyncio.sleep(5)
+ logger.error(f"{task.get_name()} 任务发生错误。")
+ return
def load_llm(self):
if 'openai' in self.config_helper.cached_config and \
From 6992249e533400d6921a4321fd7c4493c0a825ac Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 15:06:13 +0800
Subject: [PATCH 28/47] refactor: Update image downloading method in
ProviderOpenAIOfficial
---
model/provider/openai_official.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py
index 4914b9a8c..f3e9ab8e8 100644
--- a/model/provider/openai_official.py
+++ b/model/provider/openai_official.py
@@ -11,6 +11,7 @@ from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
+from util.io import download_image_by_url
from astrbot.persist.helper import dbConn
from model.provider.provider import Provider
@@ -152,7 +153,7 @@ class ProviderOpenAIOfficial(Provider):
将图片转换为 base64
'''
if image_url.startswith("http"):
- image_url = await gu.download_image_by_url(image_url)
+ image_url = await download_image_by_url(image_url)
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode()
From 32e2a7830a20eb9d262e4a75094135eedeebb00b Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 03:20:08 -0400
Subject: [PATCH 29/47] feat: Add timeout parameter to QQOfficial bot client
initialization
---
model/platform/qq_official.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index cc9762c09..d0420263c 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -81,7 +81,8 @@ class QQOfficial(Platform):
)
self.client = botClient(
intents=self.intents,
- bot_log=False
+ bot_log=False,
+ timeout=20,
)
self.client.set_platform(self)
@@ -178,7 +179,8 @@ class QQOfficial(Platform):
logger.error(traceback.format_exc())
self.client = botClient(
intents=self.intents,
- bot_log=False
+ bot_log=False,
+ timeout=20,
)
self.client.set_platform(self)
return self.client.start(
From dcec3f5f8421cd0af5a44f2749ea57e8deed5976 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 04:46:23 -0400
Subject: [PATCH 30/47] feat: unit test perf: func call improvement
---
.coveragerc | 5 +++
astrbot/bootstrap.py | 7 +++-
astrbot/message/handler.py | 19 +++++----
model/platform/qq_aiocqhttp.py | 3 ++
model/platform/qq_official.py | 4 +-
model/provider/openai_official.py | 3 ++
tests/mocks/onebot.py | 13 +++++++
tests/mocks/qq_official.py | 45 +++++++++++++++++++++
tests/test_message.py | 65 +++++++++++++++++++++++++++++++
type/types.py | 19 +++++++--
util/agent/func_call.py | 15 ++++++-
11 files changed, 182 insertions(+), 16 deletions(-)
create mode 100644 .coveragerc
create mode 100644 tests/mocks/onebot.py
create mode 100644 tests/mocks/qq_official.py
create mode 100644 tests/test_message.py
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 000000000..1385093f4
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,5 @@
+[run]
+omit =
+ */site-packages/*
+ */dist-packages/*
+ your_package_name/tests/*
\ No newline at end of file
diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py
index 36c0f010f..0eb376769 100644
--- a/astrbot/bootstrap.py
+++ b/astrbot/bootstrap.py
@@ -22,7 +22,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotBootstrap():
- def __init__(self) -> None:
+ def __init__(self) -> None:
self.context = Context()
self.config_helper = CmdConfig()
@@ -57,6 +57,8 @@ class AstrBotBootstrap():
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
else:
logger.info("未使用代理。")
+
+ self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
async def run(self):
self.command_manager = CommandManager()
@@ -80,6 +82,9 @@ class AstrBotBootstrap():
self.context.message_handler = self.message_handler
self.context.command_manager = self.command_manager
+ if self.test_mode:
+ return
+
# load plugins, plugins' commands.
self.load_plugins()
self.command_manager.register_from_pcb(self.context.plugin_command_bridge)
diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py
index 7b455b4d9..c3d26305e 100644
--- a/astrbot/message/handler.py
+++ b/astrbot/message/handler.py
@@ -1,5 +1,5 @@
import time
-import re
+import re, os
import asyncio
import traceback
import astrbot.message.unfit_words as uw
@@ -14,6 +14,7 @@ from type.command import CommandResult
from SparkleLogging.utils.core import LogManager
from logging import Logger
from nakuru.entities.components import Image
+from util.agent.func_call import FuncCall
import util.agent.web_searcher as web_searcher
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -117,6 +118,8 @@ class MessageHandler():
self.nicks = self.context.nick
self.provider = provider
self.reply_prefix = str(self.context.reply_prefix)
+
+ self.llm_tools = FuncCall(self.provider)
def set_provider(self, provider: Provider):
self.provider = provider
@@ -128,21 +131,20 @@ class MessageHandler():
`llm_provider`: the provider to use for LLM. If None, use the default provider
'''
msg_plain = message.message_str.strip()
- provider = llm_provider if llm_provider else self.provider
- inner_provider = False if llm_provider else True
+ provider = llm_provider if llm_provider else self.provider
- self.persist_manager.record_message(message.platform.platform_name, message.session_id)
+ if os.environ.get('TEST_MODE', 'off') != 'on':
+ self.persist_manager.record_message(message.platform.platform_name, message.session_id)
# TODO: this should be configurable
# if not message.message_str:
# return MessageResult("Hi~")
# check the rate limit
- if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
- # return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
- logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。")
+ if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
+ logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。")
return
-
+
# remove the nick prefix
for nick in self.nicks:
if msg_plain.startswith(nick):
@@ -183,6 +185,7 @@ class MessageHandler():
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
break
+
web_search = self.context.web_search
if not web_search and msg_plain.startswith("ws"):
# leverage web search feature
diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py
index c70e6bef3..49299201d 100644
--- a/model/platform/qq_aiocqhttp.py
+++ b/model/platform/qq_aiocqhttp.py
@@ -210,6 +210,9 @@ class AIOCQHTTP(Platform):
if isinstance(segment, Image):
image_idx.append(idx)
ret.append(d)
+ if os.environ.get('TEST_MODE', 'off') == 'on':
+ logger.info(f"回复消息: {ret}")
+ return
try:
if isinstance(message, AstrBotMessage):
await self.bot.send(message.raw_message, ret)
diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py
index d0420263c..0ca3aeea4 100644
--- a/model/platform/qq_official.py
+++ b/model/platform/qq_official.py
@@ -52,7 +52,7 @@ class botClient(Client):
class QQOfficial(Platform):
- def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
+ def __init__(self, context: Context, message_handler: MessageHandler) -> None:
super().__init__("qqofficial", context)
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
@@ -87,7 +87,7 @@ class QQOfficial(Platform):
self.client.set_platform(self)
- self.test_mode = test_mode
+ self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
plain_text = ""
diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py
index f3e9ab8e8..7ecbeab7d 100644
--- a/model/provider/openai_official.py
+++ b/model/provider/openai_official.py
@@ -296,6 +296,9 @@ class ProviderOpenAIOfficial(Provider):
extra_conf: Dict = None,
**kwargs
) -> str:
+ if os.environ.get("TEST_LLM", "off") == "on":
+ return "这是一个测试消息。"
+
super().accu_model_stat()
if not session_id:
session_id = "unknown"
diff --git a/tests/mocks/onebot.py b/tests/mocks/onebot.py
new file mode 100644
index 000000000..66df3d1ee
--- /dev/null
+++ b/tests/mocks/onebot.py
@@ -0,0 +1,13 @@
+from aiocqhttp import Event
+
+class MockOneBotMessage():
+ def __init__(self):
+ # 这些数据不是敏感的
+ self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470})
+ self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'})
+
+ def create_random_group_message(self):
+ return self.group_event_sample
+
+ def create_random_direct_message(self):
+ return self.friend_event_sample
\ No newline at end of file
diff --git a/tests/mocks/qq_official.py b/tests/mocks/qq_official.py
new file mode 100644
index 000000000..0d665d289
--- /dev/null
+++ b/tests/mocks/qq_official.py
@@ -0,0 +1,45 @@
+import botpy.message
+
+class MockQQOfficialMessage():
+ def __init__(self):
+ # 这些数据已经经过去敏处理
+ self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'}
+ self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'}
+ self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'}
+ self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849"
+
+ self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'}
+ self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
+ self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
+ self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
+
+ self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'}
+ self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
+ self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
+ self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
+
+ def create_random_group_message(self):
+ mocked = botpy.message.GroupMessage(
+ api=None,
+ event_id=self.group_event_id_sample,
+ data=self.group_plain_text_sample
+ )
+ return mocked
+
+ def create_random_guild_message(self):
+ mocked = botpy.message.Message(
+ api=None,
+ event_id=self.guild_event_id_sample,
+ data=self.guild_plain_text_sample
+ )
+ return mocked
+
+ def create_random_direct_message(self):
+ mocked = botpy.message.DirectMessage(
+ api=None,
+ event_id=self.direct_event_id_sample,
+ data=self.direct_plain_text_sample
+ )
+ return mocked
+
+
diff --git a/tests/test_message.py b/tests/test_message.py
new file mode 100644
index 000000000..a5fc4578a
--- /dev/null
+++ b/tests/test_message.py
@@ -0,0 +1,65 @@
+import asyncio
+import pytest
+import os
+
+from tests.mocks.qq_official import MockQQOfficialMessage
+from tests.mocks.onebot import MockOneBotMessage
+
+from astrbot.bootstrap import AstrBotBootstrap
+from model.platform.qq_official import QQOfficial
+from model.platform.qq_aiocqhttp import AIOCQHTTP
+from type.astrbot_message import *
+from type.message_event import *
+from SparkleLogging.utils.core import LogManager
+from logging import Formatter
+
+logger = LogManager.GetLogger(
+log_name='astrbot',
+ out_to_console=True,
+ custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
+)
+pytest_plugins = ('pytest_asyncio',)
+
+os.environ['TEST_MODE'] = 'on'
+bootstrap = AstrBotBootstrap()
+asyncio.run(bootstrap.run())
+
+qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler)
+aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler)
+
+class TestBasicMessageHandle():
+ @pytest.mark.asyncio
+ async def test_qqofficial_group_message(self):
+ group_message = MockQQOfficialMessage().create_random_group_message()
+ abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE)
+ ret = await qq_official.handle_msg(abm)
+ print(ret)
+
+ @pytest.mark.asyncio
+ async def test_qqofficial_guild_message(self):
+ guild_message = MockQQOfficialMessage().create_random_guild_message()
+ abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE)
+ ret = await qq_official.handle_msg(abm)
+ print(ret)
+
+ # 有共同性,为了节约开销,不测试频道私聊。
+ # @pytest.mark.asyncio
+ # async def test_qqofficial_private_message(self):
+ # private_message = MockQQOfficialMessage().create_random_direct_message()
+ # abm = qq_official._parse_from_qqofficial(private_message, MessageType.FRIEND_MESSAGE)
+ # ret = await qq_official.handle_msg(abm)
+ # print(ret)
+
+ @pytest.mark.asyncio
+ async def test_aiocqhttp_group_message(self):
+ event = MockOneBotMessage().create_random_group_message()
+ abm = aiocqhttp.convert_message(event)
+ ret = await aiocqhttp.handle_msg(abm)
+ print(ret)
+
+ @pytest.mark.asyncio
+ async def test_aiocqhttp_direct_message(self):
+ event = MockOneBotMessage().create_random_direct_message()
+ abm = aiocqhttp.convert_message(event)
+ ret = await aiocqhttp.handle_msg(abm)
+ print(ret)
\ No newline at end of file
diff --git a/type/types.py b/type/types.py
index 06aad685f..fd4d07d9c 100644
--- a/type/types.py
+++ b/type/types.py
@@ -1,4 +1,4 @@
-import asyncio
+import asyncio, os
from asyncio import Task
from type.register import *
from typing import List, Awaitable
@@ -12,6 +12,7 @@ from type.command import CommandResult
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
+from util.agent.func_call import FuncCall
class Context:
@@ -97,13 +98,25 @@ class Context:
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
'''
self.llms.append(RegisteredLLM(llm_name, provider, origin))
+
+ def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable):
+ '''
+ 为函数调用(function-calling / tools-use)添加工具。
+
+ @param name: 函数名
+ @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
+ @param desc: 函数描述
+ @param func_obj: 处理函数
+ '''
+ self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
if platform_name == platform.platform_name:
return platform
-
- raise ValueError("couldn't find the platform you specified")
+
+ if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错
+ raise ValueError("couldn't find the platform you specified")
async def send_message(self, unified_msg_origin: str, message: CommandResult):
'''
diff --git a/util/agent/func_call.py b/util/agent/func_call.py
index 5bb602781..e805f5bfc 100644
--- a/util/agent/func_call.py
+++ b/util/agent/func_call.py
@@ -25,6 +25,14 @@ class FuncCall():
self.provider = provider
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
+ '''
+ 为函数调用(function-calling / tools-use)添加工具。
+
+ @param name: 函数名
+ @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
+ @param desc: 函数描述
+ @param func_obj: 处理函数
+ '''
params = {
"type": "object", # hardcore here
"properties": {}
@@ -65,7 +73,10 @@ class FuncCall():
})
return _l
- async def func_call(self, question: str, func_definition: str, session_id: str=None):
+ async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple:
+
+ if not provider:
+ provider = self.provider
prompt = textwrap.dedent(f"""
ROLE:
@@ -91,7 +102,7 @@ class FuncCall():
_c = 0
while _c < 3:
try:
- res = await self.provider.text_chat(prompt, session_id)
+ res = await provider.text_chat(prompt, session_id)
print(res)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]
From 12d37381feaa52016b268310408af87d98e0bb98 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 04:49:43 -0400
Subject: [PATCH 31/47] perf: request llm api when only TEST_LLM=on
---
model/provider/openai_official.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py
index 7ecbeab7d..c428af81f 100644
--- a/model/provider/openai_official.py
+++ b/model/provider/openai_official.py
@@ -296,7 +296,7 @@ class ProviderOpenAIOfficial(Provider):
extra_conf: Dict = None,
**kwargs
) -> str:
- if os.environ.get("TEST_LLM", "off") == "on":
+ if os.environ.get("TEST_LLM", "off") != "on":
return "这是一个测试消息。"
super().accu_model_stat()
From 48c12634178d7067151b91a55d4d9df4aa626af9 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 05:02:34 -0400
Subject: [PATCH 32/47] chore: add coverage test workflow
---
.github/workflows/coverage_test.yml | 33 +++++++++++++++++++++++++++++
1 file changed, 33 insertions(+)
create mode 100644 .github/workflows/coverage_test.yml
diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
new file mode 100644
index 000000000..8dd12018f
--- /dev/null
+++ b/.github/workflows/coverage_test.yml
@@ -0,0 +1,33 @@
+name: Run tests and upload coverage
+
+on:
+ push
+
+jobs:
+ test:
+ name: Run tests and collect coverage
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Write secret to file
+ run: mkdir data & echo "$CMD_CONFIG" > data/cmd_config.json
+ env:
+ MY_SECRET: ${{ secrets.CMD_CONFIG }}
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+
+ - name: Install dependencies
+ run: pip install pytest pytest-cov pytest-asyncio
+
+ - name: Run tests
+ run: PYTHONPATH=./ pytest --cov=. tests/ -v
+
+ - name: Upload results to Codecov
+ uses: codecov/codecov-action@v4
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
\ No newline at end of file
From 67792100bbfed25d8e1db84e8f098001d0f137bf Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 05:08:08 -0400
Subject: [PATCH 33/47] refactor: Fix command configuration file creation in
coverage test workflow
---
.github/workflows/coverage_test.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
index 8dd12018f..50b3ab640 100644
--- a/.github/workflows/coverage_test.yml
+++ b/.github/workflows/coverage_test.yml
@@ -14,7 +14,7 @@ jobs:
fetch-depth: 0
- name: Write secret to file
- run: mkdir data & echo "$CMD_CONFIG" > data/cmd_config.json
+ run: mkdir data && echo "$CMD_CONFIG" > data/cmd_config.json
env:
MY_SECRET: ${{ secrets.CMD_CONFIG }}
From dcf96896efb72f9e5900b7591f67168056910f7c Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 05:10:05 -0400
Subject: [PATCH 34/47] chore: Update coverage test workflow to install
dependencies from requirements.txt
---
.github/workflows/coverage_test.yml | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
index 50b3ab640..82e82c23f 100644
--- a/.github/workflows/coverage_test.yml
+++ b/.github/workflows/coverage_test.yml
@@ -20,9 +20,12 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
-
+
- name: Install dependencies
- run: pip install pytest pytest-cov pytest-asyncio
+ run: |
+ python -m pip install --upgrade pip
+ pip install -r requirements.txt
+ pip install pytest pytest-cov pytest-asyncio
- name: Run tests
run: PYTHONPATH=./ pytest --cov=. tests/ -v
From ae6dd8929a8c5c8b8c454b10f0bd6c51eb866d15 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 05:25:45 -0400
Subject: [PATCH 35/47] refactor: Update coverage test workflow to create
command configuration file properly
---
.github/workflows/coverage_test.yml | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
index 82e82c23f..987f99aee 100644
--- a/.github/workflows/coverage_test.yml
+++ b/.github/workflows/coverage_test.yml
@@ -14,9 +14,12 @@ jobs:
fetch-depth: 0
- name: Write secret to file
- run: mkdir data && echo "$CMD_CONFIG" > data/cmd_config.json
env:
- MY_SECRET: ${{ secrets.CMD_CONFIG }}
+ MY_SECRET: ${{ secrets.CMD_CONFIG }}
+ run: |
+ mkdir data
+ touch data/cmd_config.json
+ echo "$CMD_CONFIG" > data/cmd_config.json
- name: Set up Python
uses: actions/setup-python@v4
From d3a5205bde30b534dcd89e74295673d622d58138 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 05:27:33 -0400
Subject: [PATCH 36/47] refactor: Update coverage test workflow to properly
create command configuration file
---
.github/workflows/coverage_test.yml | 8 --------
1 file changed, 8 deletions(-)
diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
index 987f99aee..ab5e3c42a 100644
--- a/.github/workflows/coverage_test.yml
+++ b/.github/workflows/coverage_test.yml
@@ -13,14 +13,6 @@ jobs:
with:
fetch-depth: 0
- - name: Write secret to file
- env:
- MY_SECRET: ${{ secrets.CMD_CONFIG }}
- run: |
- mkdir data
- touch data/cmd_config.json
- echo "$CMD_CONFIG" > data/cmd_config.json
-
- name: Set up Python
uses: actions/setup-python@v4
From 743046d48f6042b132d879ecef2ba19aa5782f46 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 05:29:52 -0400
Subject: [PATCH 37/47] chore: Create necessary directories for data and temp
in coverage test workflow
---
.github/workflows/coverage_test.yml | 3 +++
1 file changed, 3 insertions(+)
diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
index ab5e3c42a..a021daa7c 100644
--- a/.github/workflows/coverage_test.yml
+++ b/.github/workflows/coverage_test.yml
@@ -21,6 +21,9 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
+ mkdir data
+ mkdir data/config
+ mkdir temp
- name: Run tests
run: PYTHONPATH=./ pytest --cov=. tests/ -v
From 2eeb5822c1f6e369d2fb60d9cf79b5674c98b381 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 05:54:38 -0400
Subject: [PATCH 38/47] chore: add codecov.yml
---
.codecov.yml | 4 ++++
1 file changed, 4 insertions(+)
create mode 100644 .codecov.yml
diff --git a/.codecov.yml b/.codecov.yml
new file mode 100644
index 000000000..e9113832b
--- /dev/null
+++ b/.codecov.yml
@@ -0,0 +1,4 @@
+comment:
+ # add "condensed_" to "header", "files" and "footer"
+ layout: "condensed_header, condensed_files, condensed_footer"
+ hide_project_coverage: TRUE # set to true
\ No newline at end of file
From a9c6a68c5fe9021f166f688c085e4ea8d2046ad5 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Sat, 17 Aug 2024 17:59:59 +0800
Subject: [PATCH 39/47] Update README.md
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index 377037db1..de6d25467 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@
[](https://github.com/Soulter/AstrBot/releases/latest)
+[](https://codecov.io/gh/Soulter/AstrBot)
From d3b0f25cfe26c6e4afb9c2c234a01c2dcaee5f11 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sat, 17 Aug 2024 06:19:08 -0400
Subject: [PATCH 40/47] refactor: Update ProviderOpenAIOfficial to skip test
message when TEST_MODE=on
This commit updates the `ProviderOpenAIOfficial` class to skip returning the test message when the environment variable `TEST_MODE` is set to "on". This change ensures that the test message is only returned when both `TEST_LLM` and `TEST_MODE` are set to "on".
---
model/provider/openai_official.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py
index c428af81f..981f33a6b 100644
--- a/model/provider/openai_official.py
+++ b/model/provider/openai_official.py
@@ -296,7 +296,7 @@ class ProviderOpenAIOfficial(Provider):
extra_conf: Dict = None,
**kwargs
) -> str:
- if os.environ.get("TEST_LLM", "off") != "on":
+ if os.environ.get("TEST_LLM", "off") != "on" and os.environ.get("TEST_MODE", "off") == "on":
return "这是一个测试消息。"
super().accu_model_stat()
From a5db4d4e473cc0a5728ed30223741e363a45373a Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 18 Aug 2024 03:55:11 -0400
Subject: [PATCH 41/47] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=BC=82?=
=?UTF-8?q?=E7=AB=AF=E6=83=85=E5=86=B5=E4=B8=8B=E4=B8=BB=E5=8A=A8=E4=BF=A1?=
=?UTF-8?q?=E6=81=AF=E5=8F=91=E9=80=81=E5=B8=A6=E6=9C=89=E6=9C=AC=E5=9C=B0?=
=?UTF-8?q?=E5=9B=BE=E7=89=87url=E7=9A=84=E6=B6=88=E6=81=AF=E6=97=B6?=
=?UTF-8?q?=E6=8A=A5=E9=94=99=E7=9A=84=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
model/platform/qq_aiocqhttp.py | 45 ++++++++++++++++++++--------------
1 file changed, 27 insertions(+), 18 deletions(-)
diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py
index 49299201d..53e245845 100644
--- a/model/platform/qq_aiocqhttp.py
+++ b/model/platform/qq_aiocqhttp.py
@@ -80,7 +80,7 @@ class AIOCQHTTP(Platform):
def run_aiocqhttp(self):
if not self.host or not self.port:
return
- self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp')
+ self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=20)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
@@ -176,9 +176,6 @@ class AIOCQHTTP(Platform):
"""
回复用户唤醒机器人的消息。(被动回复)
"""
- logger.info(
- f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
-
res = result_message
if isinstance(res, str):
@@ -201,6 +198,13 @@ class AIOCQHTTP(Platform):
await self.record_metrics()
if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ]
+
+ if isinstance(message, AstrBotMessage):
+ logger.info(
+ f"{message.sender.user_id} <- {self.parse_message_outline(message)}")
+ else:
+ logger.info(f"回复消息: {message_chain}")
+
ret = []
image_idx = []
for idx, segment in enumerate(message_chain):
@@ -214,23 +218,13 @@ class AIOCQHTTP(Platform):
logger.info(f"回复消息: {ret}")
return
try:
- if isinstance(message, AstrBotMessage):
- await self.bot.send(message.raw_message, ret)
- if isinstance(message, dict):
- if 'group_id' in message:
- await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
- elif 'user_id' in message:
- await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
- else:
- raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
+ await self._reply_wrapper(message, ret)
except ActionFailed as e:
- logger.error(traceback.format_exc())
- logger.error(f"回复消息失败: {e}")
if e.retcode == 1200:
# ENOENT
if not image_idx:
raise e
- logger.info("检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
+ logger.warn("回复失败。检测到失败原因为文件未找到,猜测用户的协议端与 AstrBot 位于不同的文件系统上。尝试采用上传图片的方式发图。")
for idx in image_idx:
if ret[idx]['data']['file'].startswith('file://'):
logger.info(f"正在上传图片: {ret[idx]['data']['path']}")
@@ -238,8 +232,23 @@ class AIOCQHTTP(Platform):
logger.info(f"上传成功。")
ret[idx]['data']['file'] = image_url
ret[idx]['data']['path'] = image_url
- await self.bot.send(message.raw_message, ret)
-
+ await self._reply_wrapper(message, ret)
+ else:
+ logger.error(traceback.format_exc())
+ logger.error(f"回复消息失败: {e}")
+ raise e
+
+ async def _reply_wrapper(self, message: Union[AstrBotMessage, Dict], ret: List):
+ if isinstance(message, AstrBotMessage):
+ await self.bot.send(message.raw_message, ret)
+ if isinstance(message, dict):
+ if 'group_id' in message:
+ await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
+ elif 'user_id' in message:
+ await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
+ else:
+ raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
+
async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
'''
以主动的方式给QQ用户、QQ群发送一条消息。
From 3b3f75f03e0ff2c20b89a8d2172377267e82bb66 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 18 Aug 2024 04:00:45 -0400
Subject: [PATCH 42/47] =?UTF-8?q?fix:=20=E5=A2=9E=E5=A4=A7=E8=B6=85?=
=?UTF-8?q?=E6=97=B6=E6=97=B6=E9=97=B4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
model/platform/qq_aiocqhttp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py
index 53e245845..ec5f2b57e 100644
--- a/model/platform/qq_aiocqhttp.py
+++ b/model/platform/qq_aiocqhttp.py
@@ -80,7 +80,7 @@ class AIOCQHTTP(Platform):
def run_aiocqhttp(self):
if not self.host or not self.port:
return
- self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=20)
+ self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
From b5cb5eb9699b7ef1e08f82af19e35bd5002e5cf6 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 8 Sep 2024 19:41:00 +0800
Subject: [PATCH 43/47] feat: customized tool-use
---
astrbot/message/handler.py | 77 ++++++++++++++++----
model/command/internal_handler.py | 21 ++++++
type/types.py | 13 +++-
util/agent/func_call.py | 20 ++++--
util/agent/web_searcher.py | 112 ++++--------------------------
5 files changed, 129 insertions(+), 114 deletions(-)
diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py
index c3d26305e..89383a463 100644
--- a/astrbot/message/handler.py
+++ b/astrbot/message/handler.py
@@ -1,4 +1,4 @@
-import time
+import time, json
import re, os
import asyncio
import traceback
@@ -16,6 +16,8 @@ from logging import Logger
from nakuru.entities.components import Image
from util.agent.func_call import FuncCall
import util.agent.web_searcher as web_searcher
+from openai._exceptions import *
+from openai.types.chat.chat_completion_message_tool_call import Function
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -186,31 +188,82 @@ class MessageHandler():
image_url = comp.url if comp.url else comp.file
break
- web_search = self.context.web_search
- if not web_search and msg_plain.startswith("ws"):
- # leverage web search feature
- web_search = True
- msg_plain = msg_plain.removeprefix("ws").strip()
-
+ # web_search = self.context.web_search
+ # if not web_search and msg_plain.startswith("ws"):
+ # # leverage web search feature
+ # web_search = True
+ # msg_plain = msg_plain.removeprefix("ws").strip()
try:
- if web_search:
- llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, official_fc=True)
+ if not self.llm_tools.empty():
+ # tools-use
+ tool_use_flag = True
+ llm_result = await provider.text_chat(
+ prompt=msg_plain,
+ session_id=message.session_id,
+ tools=self.llm_tools.get_func()
+ )
+
+ if isinstance(llm_result, Function):
+ logger.debug(f"function-calling: {llm_result}")
+ func_obj = None
+ for i in self.llm_tools.func_list:
+ if i["name"] == llm_result.name:
+ func_obj = i["func_obj"]
+ break
+ if not func_obj:
+ return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。")
+ try:
+ args = json.loads(llm_result.arguments)
+ function_invoked_ret = await func_obj(**args)
+ has_func = True
+ except BaseException as e:
+ traceback.print_exc()
+ return MessageResult("AstrBot Function-calling 异常:" + str(e))
+ else:
+ return MessageResult(llm_result)
+
else:
+ # normal chat
+ tool_use_flag = False
llm_result = await provider.text_chat(
prompt=msg_plain,
session_id=message.session_id,
image_url=image_url
)
+ except BadRequestError as e:
+ if tool_use_flag:
+ # seems like the model don't support function-calling
+ logger.error(f"error: {e}. Using local function-calling implementation")
+
+ try:
+ # use local function-calling implementation
+ args = {
+ 'question': llm_result,
+ 'func_definition': self.llm_tools.func_dump(),
+ }
+ _, has_func = await self.llm_tools.func_call(**args)
+
+ if not has_func:
+ # normal chat
+ llm_result = await provider.text_chat(
+ prompt=msg_plain,
+ session_id=message.session_id,
+ image_url=image_url
+ )
+ except BaseException as e:
+ logger.error(traceback.format_exc())
+ return CommandResult("AstrBot Function-calling 异常:" + str(e))
+
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"LLM 调用失败。")
return MessageResult("AstrBot 请求 LLM 资源失败:" + str(e))
-
- # concatenate the reply prefix
+
+ # concatenate reply prefix
if self.reply_prefix:
llm_result = self.reply_prefix + llm_result
- # mask the unsafe content
+ # mask unsafe content
llm_result = self.content_safety_helper.filter_content(llm_result)
check = self.content_safety_helper.baidu_check(llm_result)
if not check:
diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py
index 422c226f8..9d16f8dca 100644
--- a/model/command/internal_handler.py
+++ b/model/command/internal_handler.py
@@ -9,6 +9,7 @@ from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from nakuru.entities.components import Image
+from util.agent.web_searcher import search_from_bing, fetch_website_content
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -212,6 +213,23 @@ class InternalCommandHandler:
)
elif l[1] == 'on':
context.web_search = True
+ context.register_llm_tool("web_search", [{
+ "type": "string",
+ "name": "keyword",
+ "description": "搜索关键词"
+ }],
+ "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
+ search_from_bing
+ )
+ context.register_llm_tool("fetch_website_content", [{
+ "type": "string",
+ "name": "url",
+ "description": "要获取内容的网页链接"
+ }],
+ "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
+ fetch_website_content
+ )
+
return CommandResult(
hit=True,
success=True,
@@ -219,6 +237,9 @@ class InternalCommandHandler:
)
elif l[1] == 'off':
context.web_search = False
+ context.unregister_llm_tool("web_search")
+ context.unregister_llm_tool("fetch_website_content")
+
return CommandResult(
hit=True,
success=True,
diff --git a/type/types.py b/type/types.py
index fd4d07d9c..542ab4fbf 100644
--- a/type/types.py
+++ b/type/types.py
@@ -110,6 +110,12 @@ class Context:
'''
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
+ def unregister_llm_tool(self, tool_name: str):
+ '''
+ 删除一个函数调用工具。
+ '''
+ self.message_handler.llm_tools.remove_func(tool_name)
+
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
if platform_name == platform.platform_name:
@@ -131,4 +137,9 @@ class Context:
platform_name, message_type, id = l
platform = self.find_platform(platform_name)
await platform.platform_instance.send_msg_new(MessageType(message_type), id, message)
-
\ No newline at end of file
+
+ def get_current_llm_provider(self) -> Provider:
+ '''
+ 获取当前的 LLM Provider。
+ '''
+ return self.message_handler.provider
\ No newline at end of file
diff --git a/util/agent/func_call.py b/util/agent/func_call.py
index e805f5bfc..5283ee4d6 100644
--- a/util/agent/func_call.py
+++ b/util/agent/func_call.py
@@ -23,6 +23,9 @@ class FuncCall():
def __init__(self, provider: Provider) -> None:
self.func_list = []
self.provider = provider
+
+ def empty(self) -> bool:
+ return len(self.func_list) == 0
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
'''
@@ -34,7 +37,7 @@ class FuncCall():
@param func_obj: 处理函数
'''
params = {
- "type": "object", # hardcore here
+ "type": "object", # hard-coded here
"properties": {}
}
for param in func_args:
@@ -42,14 +45,23 @@ class FuncCall():
"type": param['type'],
"description": param['description']
}
- self._func = {
+ _func = {
"name": name,
"parameters": params,
"description": desc,
"func_obj": func_obj,
}
- self.func_list.append(self._func)
-
+ self.func_list.append(_func)
+
+ def remove_func(self, name: str) -> None:
+ '''
+ 删除一个函数调用工具。
+ '''
+ for i, f in enumerate(self.func_list):
+ if f["name"] == name:
+ self.func_list.pop(i)
+ break
+
def func_dump(self) -> str:
_l = []
for f in self.func_list:
diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py
index c8634c2b7..e519ca697 100644
--- a/util/agent/web_searcher.py
+++ b/util/agent/web_searcher.py
@@ -16,6 +16,8 @@ from util.websearch.google import Google
from model.provider.provider import Provider
from SparkleLogging.utils.core import LogManager
from logging import Logger
+from type.types import Context
+from type.message_event import AstrMessageEvent
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -31,24 +33,7 @@ def tidy_text(text: str) -> str:
'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
-# def special_fetch_zhihu(link: str) -> str:
-# '''
-# function-calling 函数, 用于获取知乎文章的内容
-# '''
-# response = requests.get(link, headers=HEADERS)
-# response.encoding = "utf-8"
-# soup = BeautifulSoup(response.text, "html.parser")
-
-# if "zhuanlan.zhihu.com" in link:
-# r = soup.find(class_="Post-RichTextContainer")
-# else:
-# r = soup.find(class_="List-item").find(class_="RichContent-inner")
-# if r is None:
-# print("debug: zhihu none")
-# raise Exception("zhihu none")
-# return tidy_text(r.text)
-
-async def search_from_bing(keyword: str) -> str:
+async def search_from_bing(context: Context, ame: AstrMessageEvent, keyword: str) -> str:
'''
tools, 从 bing 搜索引擎搜索
'''
@@ -84,10 +69,11 @@ async def search_from_bing(keyword: str) -> str:
site_result = site_result[:600] + "..." if len(site_result) > 600 else site_result
ret += f"{idx}. {i.title} \n{i.snippet}\n{site_result}\n\n"
idx += 1
- return ret
+
+ return await summarize(context, ame, ret)
-async def fetch_website_content(url):
+async def fetch_website_content(context: Context, ame: AstrMessageEvent, url: str):
header = HEADERS
header.update({'User-Agent': random.choice(USER_AGENTS)})
async with aiohttp.ClientSession() as session:
@@ -97,81 +83,13 @@ async def fetch_website_content(url):
ret = doc.summary(html_partial=True)
soup = BeautifulSoup(ret, 'html.parser')
ret = tidy_text(soup.get_text())
- return ret
-
-
-async def web_search(prompt: str, provider: Provider, session_id: str, official_fc: bool=False):
- '''
- @param official_fc: 使用官方 function-calling
- '''
- new_func_call = FuncCall(provider)
-
- new_func_call.add_func("web_search", [{
- "type": "string",
- "name": "keyword",
- "description": "搜索关键词"
- }],
- "通过搜索引擎搜索。如果问题需要获取近期、实时的消息,在网页上搜索(如天气、新闻或任何需要通过网页获取信息的问题),则调用此函数;如果没有,不要调用此函数。",
- search_from_bing
- )
- new_func_call.add_func("fetch_website_content", [{
- "type": "string",
- "name": "url",
- "description": "要获取内容的网页链接"
- }],
- "获取网页的内容。如果问题带有合法的网页链接并且用户有需求了解网页内容(例如: `帮我总结一下 https://github.com 的内容`), 就调用此函数。如果没有,不要调用此函数。",
- fetch_website_content
- )
+ return await summarize(context, ame, ret)
- has_func = False
- function_invoked_ret = ""
- if official_fc:
- # we use official function-calling
- try:
- result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
- except BadRequestError as e:
- # seems dont support function-calling
- logger.error(f"error: {e}. Try to use local function-calling implementation")
- return await web_search(prompt, provider, session_id, official_fc=False)
- if isinstance(result, Function):
- logger.debug(f"function-calling: {result}")
- func_obj = None
- for i in new_func_call.func_list:
- if i["name"] == result.name:
- func_obj = i["func_obj"]
- break
- if not func_obj:
- return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
- try:
- args = json.loads(result.arguments)
- function_invoked_ret = await func_obj(**args)
- has_func = True
- except BaseException as e:
- traceback.print_exc()
- return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
- else:
- return result
- else:
- # we use our own function-calling
- try:
- args = {
- 'question': prompt,
- 'func_definition': new_func_call.func_dump(),
- }
- function_invoked_ret, has_func = await new_func_call.func_call(**args)
-
- if not has_func:
- return await provider.text_chat(prompt, session_id)
-
- except BaseException as e:
- logger.error(traceback.format_exc())
- return await provider.text_chat(prompt, session_id) + "(网页搜索失败, 此为默认回复)"
-
- if has_func:
- await provider.forget(session_id=session_id)
- summary_prompt = f"""
+async def summarize(context: Context, ame: AstrMessageEvent, text: str):
+
+ summary_prompt = f"""
你是一个专业且高效的助手,你的任务是
-1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
+1. 根据下面的相关材料对用户的问题 `{ame.message_str}` 进行总结;
2. 简单地发表你对这个问题的看法。
# 例子
@@ -183,7 +101,7 @@ async def web_search(prompt: str, provider: Provider, session_id: str, official_
2. 请**直接输出总结**,不要输出多余的内容和提示语。
# 相关材料
-{function_invoked_ret}"""
- ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
- return ret
- return function_invoked_ret
+{text}"""
+
+ provider = context.get_current_llm_provider()
+ return await provider.text_chat(prompt=summary_prompt, session_id=ame.session_id)
\ No newline at end of file
From 98863ab90160822546c4e6f65b6030a6081d2160 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 8 Sep 2024 08:16:36 -0400
Subject: [PATCH 44/47] feat: customized tool-use
---
astrbot/message/handler.py | 4 +++-
util/agent/func_call.py | 1 -
util/agent/web_searcher.py | 2 --
3 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py
index 89383a463..a20726435 100644
--- a/astrbot/message/handler.py
+++ b/astrbot/message/handler.py
@@ -214,7 +214,9 @@ class MessageHandler():
return MessageResult("AstrBot Function-calling 异常:未找到请求的函数调用。")
try:
args = json.loads(llm_result.arguments)
- function_invoked_ret = await func_obj(**args)
+ args['ame'] = message
+ args['context'] = self.context
+ llm_result = await func_obj(**args)
has_func = True
except BaseException as e:
traceback.print_exc()
diff --git a/util/agent/func_call.py b/util/agent/func_call.py
index 5283ee4d6..830496cab 100644
--- a/util/agent/func_call.py
+++ b/util/agent/func_call.py
@@ -1,6 +1,5 @@
from model.provider.provider import Provider
import json
-import time
import textwrap
class FuncCallJsonFormatError(Exception):
diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py
index e519ca697..d9b384314 100644
--- a/util/agent/web_searcher.py
+++ b/util/agent/web_searcher.py
@@ -1,6 +1,4 @@
-import traceback
import random
-import json
import aiohttp
import os
From 5b72ebaad50a1f344706cf4842f663143915919c Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 8 Sep 2024 08:23:43 -0400
Subject: [PATCH 45/47] delete: remove deprecated files
---
model/command/adapter/nonebot/command_arg.py | 2 -
model/command/adapter/nonebot/common.py | 32 ----
model/command/adapter/nonebot/driver.py | 11 --
model/command/adapter/onebot/bot.py | 2 -
model/command/adapter/onebot/message.py | 2 -
model/command/adapter/onebot/message_event.py | 2 -
.../command/adapter/onebot/message_segment.py | 2 -
model/command/adapter/protocol_adapter.py | 138 ------------------
8 files changed, 191 deletions(-)
delete mode 100644 model/command/adapter/nonebot/command_arg.py
delete mode 100644 model/command/adapter/nonebot/common.py
delete mode 100644 model/command/adapter/nonebot/driver.py
delete mode 100644 model/command/adapter/onebot/bot.py
delete mode 100644 model/command/adapter/onebot/message.py
delete mode 100644 model/command/adapter/onebot/message_event.py
delete mode 100644 model/command/adapter/onebot/message_segment.py
delete mode 100644 model/command/adapter/protocol_adapter.py
diff --git a/model/command/adapter/nonebot/command_arg.py b/model/command/adapter/nonebot/command_arg.py
deleted file mode 100644
index 3f6e8a305..000000000
--- a/model/command/adapter/nonebot/command_arg.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class CommandArg:
- pass
\ No newline at end of file
diff --git a/model/command/adapter/nonebot/common.py b/model/command/adapter/nonebot/common.py
deleted file mode 100644
index ac7532768..000000000
--- a/model/command/adapter/nonebot/common.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import sys
-from types import ModuleType
-import asyncio
-from pyppeteer import launch
-
-
-async def template_to_pic(template_path, template_name, templates, pages, wait, type, quality, device_scale_factor):
- browser = await launch()
- page = await browser.newPage()
- await page.setViewport(pages["viewport"])
- await page.goto(pages["base_url"])
- await asyncio.sleep(wait)
- await page.evaluate('''(templates) => {
- // 在页面中执行 JavaScript 代码,将数据注入到模板中
- // 这里的示例代码仅供参考,具体需要根据实际情况修改
- document.getElementById('css').innerText = templates.css;
- document.getElementById('data').innerText = JSON.stringify(templates.data);
- document.getElementById('detail').innerText = templates.detail;
- }''', templates)
- screenshot = await page.screenshot({
- 'type': type,
- 'quality': quality,
- 'deviceScaleFactor': device_scale_factor
- })
- await browser.close()
- return screenshot
-
-def require(module_str: str):
- module = ModuleType(module_str)
- sys.modules[module_str] = module
- if module_str == 'nonebot_plugin_htmlrender':
- module.template_to_pic = template_to_pic
diff --git a/model/command/adapter/nonebot/driver.py b/model/command/adapter/nonebot/driver.py
deleted file mode 100644
index a184fbcaf..000000000
--- a/model/command/adapter/nonebot/driver.py
+++ /dev/null
@@ -1,11 +0,0 @@
-class Driver:
- def __init__(self) -> None:
- self.config = {}
-
- def on_startup(self, func):
- pass
- def on_bot_connect(self, func):
- pass
-
-def get_driver():
- return Driver()
\ No newline at end of file
diff --git a/model/command/adapter/onebot/bot.py b/model/command/adapter/onebot/bot.py
deleted file mode 100644
index 19be4bcb3..000000000
--- a/model/command/adapter/onebot/bot.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class Bot:
- pass
\ No newline at end of file
diff --git a/model/command/adapter/onebot/message.py b/model/command/adapter/onebot/message.py
deleted file mode 100644
index 574b0ce16..000000000
--- a/model/command/adapter/onebot/message.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class Message:
- pass
\ No newline at end of file
diff --git a/model/command/adapter/onebot/message_event.py b/model/command/adapter/onebot/message_event.py
deleted file mode 100644
index 11712e075..000000000
--- a/model/command/adapter/onebot/message_event.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class MessageEvent:
- pass
\ No newline at end of file
diff --git a/model/command/adapter/onebot/message_segment.py b/model/command/adapter/onebot/message_segment.py
deleted file mode 100644
index 2905a1757..000000000
--- a/model/command/adapter/onebot/message_segment.py
+++ /dev/null
@@ -1,2 +0,0 @@
-class MessageSegment:
- pass
\ No newline at end of file
diff --git a/model/command/adapter/protocol_adapter.py b/model/command/adapter/protocol_adapter.py
deleted file mode 100644
index 3e182d4d7..000000000
--- a/model/command/adapter/protocol_adapter.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import sys
-from types import ModuleType
-import asyncio
-from pyppeteer import launch
-
-from model.platform.qqchan import QQChan
-
-from .nonebot.driver import Driver, get_driver
-from .onebot.message import Message
-from .onebot.message_event import MessageEvent
-from .onebot.message_segment import MessageSegment
-from .nonebot.command_arg import CommandArg
-from .onebot.bot import Bot
-from .nonebot.common import require
-
-from nakuru import (
- GuildMessage,
- GroupMessage,
- FriendMessage
-)
-
-from typing import Union
-
-NONEBOT = "nonebot"
-
-class UnifiedBotCompatibleLayer():
- def __init__(self, platform_qq_sdk: QQChan) -> None:
- # 初始化兼容层
- self.plugins: dict[str, CommandOper] = {}
- self.platform_qq_sdk = platform_qq_sdk
- self._nonebot()
- self.load_plugins()
-
- async def check_commands(self, message: str, message_obj: Union[GroupMessage, FriendMessage, GuildMessage]):
- for k in self.plugins:
- if message.startswith(k):
- if self.plugins[k].framework_name == NONEBOT:
- await self._nonebot_plugins_oper(message, message_obj, k)
-
- async def _nonebot_plugins_oper(self, message: str, message_obj: Union[GroupMessage, FriendMessage, GuildMessage], plugin_name: str = None):
- # bad implementation
- # 高并发场景下,下面的代码是不安全的
- while self.plugins[plugin_name].message_obj is not None:
- await asyncio.sleep(1)
- self.plugins[plugin_name].message_obj = message_obj
- bot, event, arg = self._nonebot_adapter(message_obj)
- await self.plugins[plugin_name].exec(bot, event, arg) # wrapper
-
- def load_plugins(self):
- import nonebot_plugin_gspanel.nonebot_plugin_gspanel
-
- def _nonebot(self):
- # 模拟 nonebot 模块
- nonebot_module = ModuleType('nonebot')
- sys.modules['nonebot'] = nonebot_module
-
- nonebot_log_module = ModuleType('nonebot.log')
- sys.modules['nonebot.log'] = nonebot_log_module
-
- nonebot_adapter_module = ModuleType('nonebot.adapters')
- sys.modules['nonebot.adapters'] = nonebot_adapter_module
-
- nonebot_params_module = ModuleType('nonebot.params')
- sys.modules['nonebot.params'] = nonebot_params_module
-
- nonebot_drivers_module = ModuleType('nonebot.drivers')
- sys.modules['nonebot.drivers'] = nonebot_drivers_module
-
- nonebot_plugin_module = ModuleType('nonebot.plugin')
- sys.modules['nonebot.plugin'] = nonebot_plugin_module
-
- nonebot_adapter_onebot_v11_module = ModuleType('nonebot.adapters.onebot.v11')
- sys.modules['nonebot.adapters.onebot.v11'] = nonebot_adapter_onebot_v11_module
-
- nonebot_adapter_onebot_v11_event_module = ModuleType('nonebot.adapters.onebot.v11.event')
- sys.modules['nonebot.adapters.onebot.v11.event'] = nonebot_adapter_onebot_v11_event_module
-
- nonebot_adapter_onebot_v11_message_module = ModuleType('nonebot.adapters.onebot.v11.message')
- sys.modules['nonebot.adapters.onebot.v11.message'] = nonebot_adapter_onebot_v11_message_module
-
- nonebot_log_module.logger = lambda: None
- nonebot_adapter_module.Message = Message
- nonebot_params_module.CommandArg = CommandArg
- on_command = wrap_on_command(self)
- nonebot_plugin_module.on_command = on_command
- nonebot_adapter_onebot_v11_module.Bot = Bot
- nonebot_adapter_onebot_v11_event_module.MessageEvent = MessageEvent
- nonebot_adapter_onebot_v11_message_module.MessageSegment = MessageSegment
- nonebot_module.get_driver = get_driver
- nonebot_module.require = require
- nonebot_drivers_module.Driver = Driver
-
- def _nonebot_adapter(self, message_obj):
- bot = Bot()
- event = MessageEvent()
- arg = CommandArg()
- # tododssss
- return bot, event, arg
-
-
-class BaseBot():
- def __init__(self, framework_name) -> None:
- self.framework_name = framework_name
-
-class CommandOper(BaseBot):
- '''
- CommandOper for NoneBot
- '''
- def __init__(self, name, aliases=None, priority=1, block=False, _ubcl: UnifiedBotCompatibleLayer = None) -> None:
- super().__init__("nonebot")
- self.name = name
- self.aliases = aliases
- self.priority = priority
- self.block = block
- self.exec = None
- self._ubcl = _ubcl
- self.message_obj: Union[GroupMessage, FriendMessage, GuildMessage] = None
- _ubcl.plugins[name] = self
-
- def handle(self):
- def decorator(func):
- async def wrapper(bot: Bot, event: MessageEvent, arg: Message = CommandArg(), *args, **kwargs):
- # 你可以在这里添加自定义的处理逻辑
- print(f"Command {self.name} is executed.")
- await func(bot, event, arg, *args, **kwargs)
- self.exec = wrapper
- return wrapper
- return decorator
-
- async def finish(self, msg, at_sender = True):
- if self.message_obj is not None:
- self._ubcl.platform_qq_sdk.send(self.message_obj, msg)
- self.message_obj = None
-
-def wrap_on_command(_ubcl: UnifiedBotCompatibleLayer):
- def on_command(name, aliases=None, priority=1, block=False):
- return CommandOper(name, aliases, priority, block, _ubcl = _ubcl)
- return on_command
From d8e70c4d7fe06e57ebe591162bb0918b5aede7b4 Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Sun, 8 Sep 2024 08:41:26 -0400
Subject: [PATCH 46/47] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20llm=20tool?=
=?UTF-8?q?=20=E8=BF=94=E5=9B=9E=E5=80=BC=E5=A4=84=E7=90=86?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
addons/plugins/helloworld/main.py | 8 ++++++++
astrbot/message/handler.py | 20 ++++++++++++++++++--
2 files changed, 26 insertions(+), 2 deletions(-)
diff --git a/addons/plugins/helloworld/main.py b/addons/plugins/helloworld/main.py
index 56eb99162..25f8e6974 100644
--- a/addons/plugins/helloworld/main.py
+++ b/addons/plugins/helloworld/main.py
@@ -21,6 +21,14 @@ class HelloWorldPlugin:
def __init__(self, context: Context) -> None:
self.context = context
self.context.register_commands("helloworld", "helloworld", "内置测试指令。", 1, self.helloworld)
+ self.context.register_llm_tool("welcome_somebody", [{
+ "type": "string",
+ "name": "name",
+ "description": "要欢迎的人的名字"
+ }], "给一个用户发送欢迎文本。", self.welcome_somebody)
+
+ async def welcome_somebody(self, name: str):
+ return CommandResult().message(f"欢迎{name}!")
"""
指令处理函数。
diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py
index a20726435..c09cd6321 100644
--- a/astrbot/message/handler.py
+++ b/astrbot/message/handler.py
@@ -216,8 +216,24 @@ class MessageHandler():
args = json.loads(llm_result.arguments)
args['ame'] = message
args['context'] = self.context
- llm_result = await func_obj(**args)
- has_func = True
+ try:
+ cmd_res = await func_obj(**args)
+ except TypeError as e:
+ args.pop('ame')
+ args.pop('context')
+ cmd_res = await func_obj(**args)
+ if isinstance(cmd_res, CommandResult):
+ return MessageResult(
+ cmd_res.message_chain,
+ is_command_call=True,
+ use_t2i=cmd_res.is_use_t2i
+ )
+ elif isinstance(cmd_res, str):
+ return MessageResult(cmd_res)
+ elif not cmd_res:
+ return
+ else:
+ return MessageResult(f"AstrBot Function-calling 异常:调用:{llm_result} 时,返回了未知的返回值类型。")
except BaseException as e:
traceback.print_exc()
return MessageResult("AstrBot Function-calling 异常:" + str(e))
From 7e48514f67903e9524b79c8ceee40a52cbfce187 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Sun, 8 Sep 2024 21:06:20 +0800
Subject: [PATCH 47/47] Update README.md
---
README.md | 1 -
1 file changed, 1 deletion(-)
diff --git a/README.md b/README.md
index de6d25467..9cd432c9c 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,6 @@
🌍 支持的消息平台
- QQ 群、QQ 频道(OneBot、QQ 官方接口)
- Telegram([astrbot_plugin_telegram](https://github.com/Soulter/astrbot_plugin_telegram) 插件)
-- WeChat(微信) ([astrbot_plugin_vchat](https://github.com/z2z63/astrbot_plugin_vchat) 插件)
🌍 支持的大模型/底座: