From dcec3f5f8421cd0af5a44f2749ea57e8deed5976 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 17 Aug 2024 04:46:23 -0400 Subject: [PATCH] feat: unit test perf: func call improvement --- .coveragerc | 5 +++ astrbot/bootstrap.py | 7 +++- astrbot/message/handler.py | 19 +++++---- model/platform/qq_aiocqhttp.py | 3 ++ model/platform/qq_official.py | 4 +- model/provider/openai_official.py | 3 ++ tests/mocks/onebot.py | 13 +++++++ tests/mocks/qq_official.py | 45 +++++++++++++++++++++ tests/test_message.py | 65 +++++++++++++++++++++++++++++++ type/types.py | 19 +++++++-- util/agent/func_call.py | 15 ++++++- 11 files changed, 182 insertions(+), 16 deletions(-) create mode 100644 .coveragerc create mode 100644 tests/mocks/onebot.py create mode 100644 tests/mocks/qq_official.py create mode 100644 tests/test_message.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..1385093f4 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,5 @@ +[run] +omit = + */site-packages/* + */dist-packages/* + your_package_name/tests/* \ No newline at end of file diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 36c0f010f..0eb376769 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -22,7 +22,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot') class AstrBotBootstrap(): - def __init__(self) -> None: + def __init__(self) -> None: self.context = Context() self.config_helper = CmdConfig() @@ -57,6 +57,8 @@ class AstrBotBootstrap(): logger.info(f"使用代理: {http_proxy}, {https_proxy}") else: logger.info("未使用代理。") + + self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' async def run(self): self.command_manager = CommandManager() @@ -80,6 +82,9 @@ class AstrBotBootstrap(): self.context.message_handler = self.message_handler self.context.command_manager = self.command_manager + if self.test_mode: + return + # load plugins, plugins' commands. self.load_plugins() self.command_manager.register_from_pcb(self.context.plugin_command_bridge) diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index 7b455b4d9..c3d26305e 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -1,5 +1,5 @@ import time -import re +import re, os import asyncio import traceback import astrbot.message.unfit_words as uw @@ -14,6 +14,7 @@ from type.command import CommandResult from SparkleLogging.utils.core import LogManager from logging import Logger from nakuru.entities.components import Image +from util.agent.func_call import FuncCall import util.agent.web_searcher as web_searcher logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -117,6 +118,8 @@ class MessageHandler(): self.nicks = self.context.nick self.provider = provider self.reply_prefix = str(self.context.reply_prefix) + + self.llm_tools = FuncCall(self.provider) def set_provider(self, provider: Provider): self.provider = provider @@ -128,21 +131,20 @@ class MessageHandler(): `llm_provider`: the provider to use for LLM. If None, use the default provider ''' msg_plain = message.message_str.strip() - provider = llm_provider if llm_provider else self.provider - inner_provider = False if llm_provider else True + provider = llm_provider if llm_provider else self.provider - self.persist_manager.record_message(message.platform.platform_name, message.session_id) + if os.environ.get('TEST_MODE', 'off') != 'on': + self.persist_manager.record_message(message.platform.platform_name, message.session_id) # TODO: this should be configurable # if not message.message_str: # return MessageResult("Hi~") # check the rate limit - if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id): - # return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。') - logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。") + if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id): + logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。") return - + # remove the nick prefix for nick in self.nicks: if msg_plain.startswith(nick): @@ -183,6 +185,7 @@ class MessageHandler(): if isinstance(comp, Image): image_url = comp.url if comp.url else comp.file break + web_search = self.context.web_search if not web_search and msg_plain.startswith("ws"): # leverage web search feature diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index c70e6bef3..49299201d 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -210,6 +210,9 @@ class AIOCQHTTP(Platform): if isinstance(segment, Image): image_idx.append(idx) ret.append(d) + if os.environ.get('TEST_MODE', 'off') == 'on': + logger.info(f"回复消息: {ret}") + return try: if isinstance(message, AstrBotMessage): await self.bot.send(message.raw_message, ret) diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index d0420263c..0ca3aeea4 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -52,7 +52,7 @@ class botClient(Client): class QQOfficial(Platform): - def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None: + def __init__(self, context: Context, message_handler: MessageHandler) -> None: super().__init__("qqofficial", context) self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) @@ -87,7 +87,7 @@ class QQOfficial(Platform): self.client.set_platform(self) - self.test_mode = test_mode + self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False): plain_text = "" diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index f3e9ab8e8..7ecbeab7d 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -296,6 +296,9 @@ class ProviderOpenAIOfficial(Provider): extra_conf: Dict = None, **kwargs ) -> str: + if os.environ.get("TEST_LLM", "off") == "on": + return "这是一个测试消息。" + super().accu_model_stat() if not session_id: session_id = "unknown" diff --git a/tests/mocks/onebot.py b/tests/mocks/onebot.py new file mode 100644 index 000000000..66df3d1ee --- /dev/null +++ b/tests/mocks/onebot.py @@ -0,0 +1,13 @@ +from aiocqhttp import Event + +class MockOneBotMessage(): + def __init__(self): + # 这些数据不是敏感的 + self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470}) + self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'}) + + def create_random_group_message(self): + return self.group_event_sample + + def create_random_direct_message(self): + return self.friend_event_sample \ No newline at end of file diff --git a/tests/mocks/qq_official.py b/tests/mocks/qq_official.py new file mode 100644 index 000000000..0d665d289 --- /dev/null +++ b/tests/mocks/qq_official.py @@ -0,0 +1,45 @@ +import botpy.message + +class MockQQOfficialMessage(): + def __init__(self): + # 这些数据已经经过去敏处理 + self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'} + self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'} + self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'} + self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849" + + self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'} + self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'} + self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'} + self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" + + self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'} + self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'} + self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'} + self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" + + def create_random_group_message(self): + mocked = botpy.message.GroupMessage( + api=None, + event_id=self.group_event_id_sample, + data=self.group_plain_text_sample + ) + return mocked + + def create_random_guild_message(self): + mocked = botpy.message.Message( + api=None, + event_id=self.guild_event_id_sample, + data=self.guild_plain_text_sample + ) + return mocked + + def create_random_direct_message(self): + mocked = botpy.message.DirectMessage( + api=None, + event_id=self.direct_event_id_sample, + data=self.direct_plain_text_sample + ) + return mocked + + diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 000000000..a5fc4578a --- /dev/null +++ b/tests/test_message.py @@ -0,0 +1,65 @@ +import asyncio +import pytest +import os + +from tests.mocks.qq_official import MockQQOfficialMessage +from tests.mocks.onebot import MockOneBotMessage + +from astrbot.bootstrap import AstrBotBootstrap +from model.platform.qq_official import QQOfficial +from model.platform.qq_aiocqhttp import AIOCQHTTP +from type.astrbot_message import * +from type.message_event import * +from SparkleLogging.utils.core import LogManager +from logging import Formatter + +logger = LogManager.GetLogger( +log_name='astrbot', + out_to_console=True, + custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S") +) +pytest_plugins = ('pytest_asyncio',) + +os.environ['TEST_MODE'] = 'on' +bootstrap = AstrBotBootstrap() +asyncio.run(bootstrap.run()) + +qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler) +aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler) + +class TestBasicMessageHandle(): + @pytest.mark.asyncio + async def test_qqofficial_group_message(self): + group_message = MockQQOfficialMessage().create_random_group_message() + abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE) + ret = await qq_official.handle_msg(abm) + print(ret) + + @pytest.mark.asyncio + async def test_qqofficial_guild_message(self): + guild_message = MockQQOfficialMessage().create_random_guild_message() + abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE) + ret = await qq_official.handle_msg(abm) + print(ret) + + # 有共同性,为了节约开销,不测试频道私聊。 + # @pytest.mark.asyncio + # async def test_qqofficial_private_message(self): + # private_message = MockQQOfficialMessage().create_random_direct_message() + # abm = qq_official._parse_from_qqofficial(private_message, MessageType.FRIEND_MESSAGE) + # ret = await qq_official.handle_msg(abm) + # print(ret) + + @pytest.mark.asyncio + async def test_aiocqhttp_group_message(self): + event = MockOneBotMessage().create_random_group_message() + abm = aiocqhttp.convert_message(event) + ret = await aiocqhttp.handle_msg(abm) + print(ret) + + @pytest.mark.asyncio + async def test_aiocqhttp_direct_message(self): + event = MockOneBotMessage().create_random_direct_message() + abm = aiocqhttp.convert_message(event) + ret = await aiocqhttp.handle_msg(abm) + print(ret) \ No newline at end of file diff --git a/type/types.py b/type/types.py index 06aad685f..fd4d07d9c 100644 --- a/type/types.py +++ b/type/types.py @@ -1,4 +1,4 @@ -import asyncio +import asyncio, os from asyncio import Task from type.register import * from typing import List, Awaitable @@ -12,6 +12,7 @@ from type.command import CommandResult from type.astrbot_message import MessageType from model.plugin.command import PluginCommandBridge from model.provider.provider import Provider +from util.agent.func_call import FuncCall class Context: @@ -97,13 +98,25 @@ class Context: `provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。 ''' self.llms.append(RegisteredLLM(llm_name, provider, origin)) + + def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable): + ''' + 为函数调用(function-calling / tools-use)添加工具。 + + @param name: 函数名 + @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] + @param desc: 函数描述 + @param func_obj: 处理函数 + ''' + self.message_handler.llm_tools.add_func(tool_name, params, desc, func) def find_platform(self, platform_name: str) -> RegisteredPlatform: for platform in self.platforms: if platform_name == platform.platform_name: return platform - - raise ValueError("couldn't find the platform you specified") + + if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错 + raise ValueError("couldn't find the platform you specified") async def send_message(self, unified_msg_origin: str, message: CommandResult): ''' diff --git a/util/agent/func_call.py b/util/agent/func_call.py index 5bb602781..e805f5bfc 100644 --- a/util/agent/func_call.py +++ b/util/agent/func_call.py @@ -25,6 +25,14 @@ class FuncCall(): self.provider = provider def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None: + ''' + 为函数调用(function-calling / tools-use)添加工具。 + + @param name: 函数名 + @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] + @param desc: 函数描述 + @param func_obj: 处理函数 + ''' params = { "type": "object", # hardcore here "properties": {} @@ -65,7 +73,10 @@ class FuncCall(): }) return _l - async def func_call(self, question: str, func_definition: str, session_id: str=None): + async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple: + + if not provider: + provider = self.provider prompt = textwrap.dedent(f""" ROLE: @@ -91,7 +102,7 @@ class FuncCall(): _c = 0 while _c < 3: try: - res = await self.provider.text_chat(prompt, session_id) + res = await provider.text_chat(prompt, session_id) print(res) if res.find('```') != -1: res = res[res.find('```json') + 7: res.rfind('```')]