feat: unit test

perf: func call improvement
This commit is contained in:
Soulter
2024-08-17 04:46:23 -04:00
parent 32e2a7830a
commit dcec3f5f84
11 changed files with 182 additions and 16 deletions
+5
View File
@@ -0,0 +1,5 @@
[run]
omit =
*/site-packages/*
*/dist-packages/*
your_package_name/tests/*
+6 -1
View File
@@ -22,7 +22,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotBootstrap():
def __init__(self) -> None:
def __init__(self) -> None:
self.context = Context()
self.config_helper = CmdConfig()
@@ -57,6 +57,8 @@ class AstrBotBootstrap():
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
else:
logger.info("未使用代理。")
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
async def run(self):
self.command_manager = CommandManager()
@@ -80,6 +82,9 @@ class AstrBotBootstrap():
self.context.message_handler = self.message_handler
self.context.command_manager = self.command_manager
if self.test_mode:
return
# load plugins, plugins' commands.
self.load_plugins()
self.command_manager.register_from_pcb(self.context.plugin_command_bridge)
+11 -8
View File
@@ -1,5 +1,5 @@
import time
import re
import re, os
import asyncio
import traceback
import astrbot.message.unfit_words as uw
@@ -14,6 +14,7 @@ from type.command import CommandResult
from SparkleLogging.utils.core import LogManager
from logging import Logger
from nakuru.entities.components import Image
from util.agent.func_call import FuncCall
import util.agent.web_searcher as web_searcher
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -117,6 +118,8 @@ class MessageHandler():
self.nicks = self.context.nick
self.provider = provider
self.reply_prefix = str(self.context.reply_prefix)
self.llm_tools = FuncCall(self.provider)
def set_provider(self, provider: Provider):
self.provider = provider
@@ -128,21 +131,20 @@ class MessageHandler():
`llm_provider`: the provider to use for LLM. If None, use the default provider
'''
msg_plain = message.message_str.strip()
provider = llm_provider if llm_provider else self.provider
inner_provider = False if llm_provider else True
provider = llm_provider if llm_provider else self.provider
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
if os.environ.get('TEST_MODE', 'off') != 'on':
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
# TODO: this should be configurable
# if not message.message_str:
# return MessageResult("Hi~")
# check the rate limit
if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
# return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。")
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。")
return
# remove the nick prefix
for nick in self.nicks:
if msg_plain.startswith(nick):
@@ -183,6 +185,7 @@ class MessageHandler():
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
break
web_search = self.context.web_search
if not web_search and msg_plain.startswith("ws"):
# leverage web search feature
+3
View File
@@ -210,6 +210,9 @@ class AIOCQHTTP(Platform):
if isinstance(segment, Image):
image_idx.append(idx)
ret.append(d)
if os.environ.get('TEST_MODE', 'off') == 'on':
logger.info(f"回复消息: {ret}")
return
try:
if isinstance(message, AstrBotMessage):
await self.bot.send(message.raw_message, ret)
+2 -2
View File
@@ -52,7 +52,7 @@ class botClient(Client):
class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
super().__init__("qqofficial", context)
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
@@ -87,7 +87,7 @@ class QQOfficial(Platform):
self.client.set_platform(self)
self.test_mode = test_mode
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
plain_text = ""
+3
View File
@@ -296,6 +296,9 @@ class ProviderOpenAIOfficial(Provider):
extra_conf: Dict = None,
**kwargs
) -> str:
if os.environ.get("TEST_LLM", "off") == "on":
return "这是一个测试消息。"
super().accu_model_stat()
if not session_id:
session_id = "unknown"
+13
View File
@@ -0,0 +1,13 @@
from aiocqhttp import Event
class MockOneBotMessage():
def __init__(self):
# 这些数据不是敏感的
self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470})
self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'})
def create_random_group_message(self):
return self.group_event_sample
def create_random_direct_message(self):
return self.friend_event_sample
+45
View File
@@ -0,0 +1,45 @@
import botpy.message
class MockQQOfficialMessage():
def __init__(self):
# 这些数据已经经过去敏处理
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'}
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'}
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'}
self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849"
self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'}
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'}
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
def create_random_group_message(self):
mocked = botpy.message.GroupMessage(
api=None,
event_id=self.group_event_id_sample,
data=self.group_plain_text_sample
)
return mocked
def create_random_guild_message(self):
mocked = botpy.message.Message(
api=None,
event_id=self.guild_event_id_sample,
data=self.guild_plain_text_sample
)
return mocked
def create_random_direct_message(self):
mocked = botpy.message.DirectMessage(
api=None,
event_id=self.direct_event_id_sample,
data=self.direct_plain_text_sample
)
return mocked
+65
View File
@@ -0,0 +1,65 @@
import asyncio
import pytest
import os
from tests.mocks.qq_official import MockQQOfficialMessage
from tests.mocks.onebot import MockOneBotMessage
from astrbot.bootstrap import AstrBotBootstrap
from model.platform.qq_official import QQOfficial
from model.platform.qq_aiocqhttp import AIOCQHTTP
from type.astrbot_message import *
from type.message_event import *
from SparkleLogging.utils.core import LogManager
from logging import Formatter
logger = LogManager.GetLogger(
log_name='astrbot',
out_to_console=True,
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
)
pytest_plugins = ('pytest_asyncio',)
os.environ['TEST_MODE'] = 'on'
bootstrap = AstrBotBootstrap()
asyncio.run(bootstrap.run())
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler)
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler)
class TestBasicMessageHandle():
@pytest.mark.asyncio
async def test_qqofficial_group_message(self):
group_message = MockQQOfficialMessage().create_random_group_message()
abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE)
ret = await qq_official.handle_msg(abm)
print(ret)
@pytest.mark.asyncio
async def test_qqofficial_guild_message(self):
guild_message = MockQQOfficialMessage().create_random_guild_message()
abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE)
ret = await qq_official.handle_msg(abm)
print(ret)
# 有共同性,为了节约开销,不测试频道私聊。
# @pytest.mark.asyncio
# async def test_qqofficial_private_message(self):
# private_message = MockQQOfficialMessage().create_random_direct_message()
# abm = qq_official._parse_from_qqofficial(private_message, MessageType.FRIEND_MESSAGE)
# ret = await qq_official.handle_msg(abm)
# print(ret)
@pytest.mark.asyncio
async def test_aiocqhttp_group_message(self):
event = MockOneBotMessage().create_random_group_message()
abm = aiocqhttp.convert_message(event)
ret = await aiocqhttp.handle_msg(abm)
print(ret)
@pytest.mark.asyncio
async def test_aiocqhttp_direct_message(self):
event = MockOneBotMessage().create_random_direct_message()
abm = aiocqhttp.convert_message(event)
ret = await aiocqhttp.handle_msg(abm)
print(ret)
+16 -3
View File
@@ -1,4 +1,4 @@
import asyncio
import asyncio, os
from asyncio import Task
from type.register import *
from typing import List, Awaitable
@@ -12,6 +12,7 @@ from type.command import CommandResult
from type.astrbot_message import MessageType
from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
from util.agent.func_call import FuncCall
class Context:
@@ -97,13 +98,25 @@ class Context:
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
'''
self.llms.append(RegisteredLLM(llm_name, provider, origin))
def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable):
'''
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 处理函数
'''
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms:
if platform_name == platform.platform_name:
return platform
raise ValueError("couldn't find the platform you specified")
if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错
raise ValueError("couldn't find the platform you specified")
async def send_message(self, unified_msg_origin: str, message: CommandResult):
'''
+13 -2
View File
@@ -25,6 +25,14 @@ class FuncCall():
self.provider = provider
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
'''
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 处理函数
'''
params = {
"type": "object", # hardcore here
"properties": {}
@@ -65,7 +73,10 @@ class FuncCall():
})
return _l
async def func_call(self, question: str, func_definition: str, session_id: str=None):
async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple:
if not provider:
provider = self.provider
prompt = textwrap.dedent(f"""
ROLE:
@@ -91,7 +102,7 @@ class FuncCall():
_c = 0
while _c < 3:
try:
res = await self.provider.text_chat(prompt, session_id)
res = await provider.text_chat(prompt, session_id)
print(res)
if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')]