feat: unit test
perf: func call improvement
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
[run]
|
||||
omit =
|
||||
*/site-packages/*
|
||||
*/dist-packages/*
|
||||
your_package_name/tests/*
|
||||
@@ -22,7 +22,7 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
self.config_helper = CmdConfig()
|
||||
|
||||
@@ -57,6 +57,8 @@ class AstrBotBootstrap():
|
||||
logger.info(f"使用代理: {http_proxy}, {https_proxy}")
|
||||
else:
|
||||
logger.info("未使用代理。")
|
||||
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
async def run(self):
|
||||
self.command_manager = CommandManager()
|
||||
@@ -80,6 +82,9 @@ class AstrBotBootstrap():
|
||||
self.context.message_handler = self.message_handler
|
||||
self.context.command_manager = self.command_manager
|
||||
|
||||
if self.test_mode:
|
||||
return
|
||||
|
||||
# load plugins, plugins' commands.
|
||||
self.load_plugins()
|
||||
self.command_manager.register_from_pcb(self.context.plugin_command_bridge)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
import re
|
||||
import re, os
|
||||
import asyncio
|
||||
import traceback
|
||||
import astrbot.message.unfit_words as uw
|
||||
@@ -14,6 +14,7 @@ from type.command import CommandResult
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.func_call import FuncCall
|
||||
import util.agent.web_searcher as web_searcher
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -117,6 +118,8 @@ class MessageHandler():
|
||||
self.nicks = self.context.nick
|
||||
self.provider = provider
|
||||
self.reply_prefix = str(self.context.reply_prefix)
|
||||
|
||||
self.llm_tools = FuncCall(self.provider)
|
||||
|
||||
def set_provider(self, provider: Provider):
|
||||
self.provider = provider
|
||||
@@ -128,21 +131,20 @@ class MessageHandler():
|
||||
`llm_provider`: the provider to use for LLM. If None, use the default provider
|
||||
'''
|
||||
msg_plain = message.message_str.strip()
|
||||
provider = llm_provider if llm_provider else self.provider
|
||||
inner_provider = False if llm_provider else True
|
||||
provider = llm_provider if llm_provider else self.provider
|
||||
|
||||
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
|
||||
if os.environ.get('TEST_MODE', 'off') != 'on':
|
||||
self.persist_manager.record_message(message.platform.platform_name, message.session_id)
|
||||
|
||||
# TODO: this should be configurable
|
||||
# if not message.message_str:
|
||||
# return MessageResult("Hi~")
|
||||
|
||||
# check the rate limit
|
||||
if not message.only_command and not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
|
||||
# return MessageResult(f'你的发言超过频率限制(╯▔皿▔)╯。\n管理员设置 {self.rate_limit_helper.rate_limit_time} 秒内只能提问{self.rate_limit_helper.rate_limit_count} 次。')
|
||||
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制, 跳过。")
|
||||
if not self.rate_limit_helper.check_frequency(message.message_obj.sender.user_id):
|
||||
logger.warning(f"用户 {message.message_obj.sender.user_id} 的发言频率超过限制,已忽略。")
|
||||
return
|
||||
|
||||
|
||||
# remove the nick prefix
|
||||
for nick in self.nicks:
|
||||
if msg_plain.startswith(nick):
|
||||
@@ -183,6 +185,7 @@ class MessageHandler():
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
break
|
||||
|
||||
web_search = self.context.web_search
|
||||
if not web_search and msg_plain.startswith("ws"):
|
||||
# leverage web search feature
|
||||
|
||||
@@ -210,6 +210,9 @@ class AIOCQHTTP(Platform):
|
||||
if isinstance(segment, Image):
|
||||
image_idx.append(idx)
|
||||
ret.append(d)
|
||||
if os.environ.get('TEST_MODE', 'off') == 'on':
|
||||
logger.info(f"回复消息: {ret}")
|
||||
return
|
||||
try:
|
||||
if isinstance(message, AstrBotMessage):
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
|
||||
@@ -52,7 +52,7 @@ class botClient(Client):
|
||||
|
||||
class QQOfficial(Platform):
|
||||
|
||||
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
super().__init__("qqofficial", context)
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
@@ -87,7 +87,7 @@ class QQOfficial(Platform):
|
||||
|
||||
self.client.set_platform(self)
|
||||
|
||||
self.test_mode = test_mode
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
|
||||
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
|
||||
plain_text = ""
|
||||
|
||||
@@ -296,6 +296,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
extra_conf: Dict = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
if os.environ.get("TEST_LLM", "off") == "on":
|
||||
return "这是一个测试消息。"
|
||||
|
||||
super().accu_model_stat()
|
||||
if not session_id:
|
||||
session_id = "unknown"
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from aiocqhttp import Event
|
||||
|
||||
class MockOneBotMessage():
|
||||
def __init__(self):
|
||||
# 这些数据不是敏感的
|
||||
self.group_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882500, 'message_id': -2147480159, 'message_seq': -2147480159, 'real_id': -2147480159, 'message_type': 'group', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': '', 'role': 'owner'}, 'raw_message': '[CQ:at,qq=3430871669] just reply me `ok`', 'font': 14, 'sub_type': 'normal', 'message': [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': ' just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message', 'group_id': 849750470})
|
||||
self.friend_event_sample = Event.from_payload({'self_id': 3430871669, 'user_id': 905617992, 'time': 1723882599, 'message_id': -2147480157, 'message_seq': -2147480157, 'real_id': -2147480157, 'message_type': 'private', 'sender': {'user_id': 905617992, 'nickname': 'Soulter', 'card': ''}, 'raw_message': 'just reply me `ok`', 'font': 14, 'sub_type': 'friend', 'message': [{'data': {'text': 'just reply me `ok`'}, 'type': 'text'}], 'message_format': 'array', 'post_type': 'message'})
|
||||
|
||||
def create_random_group_message(self):
|
||||
return self.group_event_sample
|
||||
|
||||
def create_random_direct_message(self):
|
||||
return self.friend_event_sample
|
||||
@@ -0,0 +1,45 @@
|
||||
import botpy.message
|
||||
|
||||
class MockQQOfficialMessage():
|
||||
def __init__(self):
|
||||
# 这些数据已经经过去敏处理
|
||||
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'}
|
||||
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'}
|
||||
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'}
|
||||
self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849"
|
||||
|
||||
self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'}
|
||||
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
|
||||
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
|
||||
self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'}
|
||||
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
|
||||
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
|
||||
self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
def create_random_group_message(self):
|
||||
mocked = botpy.message.GroupMessage(
|
||||
api=None,
|
||||
event_id=self.group_event_id_sample,
|
||||
data=self.group_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_random_guild_message(self):
|
||||
mocked = botpy.message.Message(
|
||||
api=None,
|
||||
event_id=self.guild_event_id_sample,
|
||||
data=self.guild_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_random_direct_message(self):
|
||||
mocked = botpy.message.DirectMessage(
|
||||
api=None,
|
||||
event_id=self.direct_event_id_sample,
|
||||
data=self.direct_plain_text_sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import asyncio
|
||||
import pytest
|
||||
import os
|
||||
|
||||
from tests.mocks.qq_official import MockQQOfficialMessage
|
||||
from tests.mocks.onebot import MockOneBotMessage
|
||||
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
custom_formatter=Formatter('[%(asctime)s| %(name)s - %(levelname)s|%(filename)s:%(lineno)d]: %(message)s', datefmt="%H:%M:%S")
|
||||
)
|
||||
pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
os.environ['TEST_MODE'] = 'on'
|
||||
bootstrap = AstrBotBootstrap()
|
||||
asyncio.run(bootstrap.run())
|
||||
|
||||
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler)
|
||||
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler)
|
||||
|
||||
class TestBasicMessageHandle():
|
||||
@pytest.mark.asyncio
|
||||
async def test_qqofficial_group_message(self):
|
||||
group_message = MockQQOfficialMessage().create_random_group_message()
|
||||
abm = qq_official._parse_from_qqofficial(group_message, MessageType.GROUP_MESSAGE)
|
||||
ret = await qq_official.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qqofficial_guild_message(self):
|
||||
guild_message = MockQQOfficialMessage().create_random_guild_message()
|
||||
abm = qq_official._parse_from_qqofficial(guild_message, MessageType.GUILD_MESSAGE)
|
||||
ret = await qq_official.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
# 有共同性,为了节约开销,不测试频道私聊。
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_qqofficial_private_message(self):
|
||||
# private_message = MockQQOfficialMessage().create_random_direct_message()
|
||||
# abm = qq_official._parse_from_qqofficial(private_message, MessageType.FRIEND_MESSAGE)
|
||||
# ret = await qq_official.handle_msg(abm)
|
||||
# print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_group_message(self):
|
||||
event = MockOneBotMessage().create_random_group_message()
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aiocqhttp_direct_message(self):
|
||||
event = MockOneBotMessage().create_random_direct_message()
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
+16
-3
@@ -1,4 +1,4 @@
|
||||
import asyncio
|
||||
import asyncio, os
|
||||
from asyncio import Task
|
||||
from type.register import *
|
||||
from typing import List, Awaitable
|
||||
@@ -12,6 +12,7 @@ from type.command import CommandResult
|
||||
from type.astrbot_message import MessageType
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from model.provider.provider import Provider
|
||||
from util.agent.func_call import FuncCall
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -97,13 +98,25 @@ class Context:
|
||||
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
|
||||
'''
|
||||
self.llms.append(RegisteredLLM(llm_name, provider, origin))
|
||||
|
||||
def register_llm_tool(self, tool_name: str, params: list, desc: str, func: callable):
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
self.message_handler.llm_tools.add_func(tool_name, params, desc, func)
|
||||
|
||||
def find_platform(self, platform_name: str) -> RegisteredPlatform:
|
||||
for platform in self.platforms:
|
||||
if platform_name == platform.platform_name:
|
||||
return platform
|
||||
|
||||
raise ValueError("couldn't find the platform you specified")
|
||||
|
||||
if not os.environ.get('TEST_MODE', 'off') == 'on': # 测试模式下不报错
|
||||
raise ValueError("couldn't find the platform you specified")
|
||||
|
||||
async def send_message(self, unified_msg_origin: str, message: CommandResult):
|
||||
'''
|
||||
|
||||
+13
-2
@@ -25,6 +25,14 @@ class FuncCall():
|
||||
self.provider = provider
|
||||
|
||||
def add_func(self, name: str, func_args: list, desc: str, func_obj: callable) -> None:
|
||||
'''
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
@param name: 函数名
|
||||
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
|
||||
@param desc: 函数描述
|
||||
@param func_obj: 处理函数
|
||||
'''
|
||||
params = {
|
||||
"type": "object", # hardcore here
|
||||
"properties": {}
|
||||
@@ -65,7 +73,10 @@ class FuncCall():
|
||||
})
|
||||
return _l
|
||||
|
||||
async def func_call(self, question: str, func_definition: str, session_id: str=None):
|
||||
async def func_call(self, question: str, func_definition: str, session_id: str, provider: Provider = None) -> tuple:
|
||||
|
||||
if not provider:
|
||||
provider = self.provider
|
||||
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
@@ -91,7 +102,7 @@ class FuncCall():
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = await self.provider.text_chat(prompt, session_id)
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
print(res)
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') + 7: res.rfind('```')]
|
||||
|
||||
Reference in New Issue
Block a user