Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b241b0f954 | |||
| 171dd1dc02 | |||
| af62d969d7 | |||
| c4fd9a66c6 | |||
| d191997a39 | |||
| 853ac4c104 | |||
| ed053acad6 | |||
| f147634e51 | |||
| e3b2a68341 | |||
| 84c450aef9 |
@@ -10,3 +10,4 @@ cmd_config.json
|
||||
data/*
|
||||
cookies.json
|
||||
logs/
|
||||
addons/plugins
|
||||
@@ -101,6 +101,7 @@ class AstrBotBootstrap():
|
||||
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
|
||||
self.llm_instance = ProviderOpenAIOfficial(self.context)
|
||||
self.openai_command_handler.set_provider(self.llm_instance)
|
||||
self.context.register_provider("internal_openai", self.llm_instance)
|
||||
logger.info("已启用 OpenAI API 支持。")
|
||||
|
||||
def load_plugins(self):
|
||||
|
||||
@@ -115,6 +115,9 @@ class MessageHandler():
|
||||
self.nicks = self.context.nick
|
||||
self.provider = provider
|
||||
self.reply_prefix = self.context.reply_prefix
|
||||
|
||||
def set_provider(self, provider: Provider):
|
||||
self.provider = provider
|
||||
|
||||
async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult:
|
||||
'''
|
||||
@@ -148,7 +151,8 @@ class MessageHandler():
|
||||
assert(isinstance(cmd_res, CommandResult))
|
||||
return MessageResult(
|
||||
cmd_res.message_chain,
|
||||
is_command_call=True
|
||||
is_command_call=True,
|
||||
use_t2i=cmd_res.is_use_t2i
|
||||
)
|
||||
|
||||
# check if the message is a llm-wake-up command
|
||||
@@ -178,7 +182,9 @@ class MessageHandler():
|
||||
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider)
|
||||
else:
|
||||
llm_result = await provider.text_chat(
|
||||
msg_plain, message.session_id, image_url
|
||||
prompt=msg_plain,
|
||||
session_id=message.session_id,
|
||||
image_url=image_url
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
+1
-1
@@ -285,7 +285,7 @@ class AstrBotDashBoard():
|
||||
ret = self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret),
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"has_new_version": ret is not None
|
||||
}
|
||||
|
||||
@@ -26,7 +26,36 @@ class InternalCommandHandler:
|
||||
self.manager.register("websearch", "网页搜索开关", 10, self.web_search)
|
||||
self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle)
|
||||
self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid)
|
||||
|
||||
self.manager.register("provider", "查看和切换当前使用的 LLM 资源来源", 10, self.provider)
|
||||
|
||||
def provider(self, message: AstrMessageEvent, context: Context):
|
||||
if len(context.llms) == 0:
|
||||
return CommandResult().message("当前没有加载任何 LLM 资源。")
|
||||
|
||||
tokens = self.manager.command_parser.parse(message.message_str)
|
||||
|
||||
if tokens.len == 1:
|
||||
ret = "## 当前载入的 LLM 资源\n"
|
||||
for idx, llm in enumerate(context.llms):
|
||||
ret += f"{idx}. {llm.llm_name}"
|
||||
if llm.origin:
|
||||
ret += f" (来源: {llm.origin})"
|
||||
if context.message_handler.provider == llm.llm_instance:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
ret += "\n使用 provider <序号> 切换 LLM 资源。"
|
||||
return CommandResult().message(ret)
|
||||
else:
|
||||
try:
|
||||
idx = int(tokens.get(1))
|
||||
if idx >= len(context.llms):
|
||||
return CommandResult().message("provider: 无效的序号。")
|
||||
context.message_handler.set_provider(context.llms[idx].llm_instance)
|
||||
return CommandResult().message(f"已经成功切换到 LLM 资源 {context.llms[idx].llm_name}。")
|
||||
except BaseException as e:
|
||||
return CommandResult().message("provider: 参数错误。")
|
||||
|
||||
def set_nick(self, message: AstrMessageEvent, context: Context):
|
||||
message_str = message.message_str
|
||||
if message.role != "admin":
|
||||
|
||||
@@ -73,7 +73,8 @@ class CommandManager():
|
||||
if message_str.startswith(command):
|
||||
logger.info(f"触发 {command} 指令。")
|
||||
command_result = await self.execute_handler(command, message_event, context)
|
||||
return command_result
|
||||
if command_result.hit:
|
||||
return command_result
|
||||
|
||||
async def execute_handler(self,
|
||||
command: str,
|
||||
|
||||
@@ -2,6 +2,7 @@ import abc
|
||||
from typing import Union, Any, List
|
||||
from nakuru.entities.components import Plain, At, Image, BaseMessageComponent
|
||||
from type.astrbot_message import AstrBotMessage
|
||||
from type.command import CommandResult
|
||||
|
||||
|
||||
class Platform():
|
||||
@@ -24,7 +25,7 @@ class Platform():
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send_msg(self, target: Any, result_message: Union[List[BaseMessageComponent], str]):
|
||||
async def send_msg(self, target: Any, result_message: CommandResult):
|
||||
'''
|
||||
发送消息(主动)
|
||||
'''
|
||||
|
||||
@@ -7,6 +7,7 @@ from aiocqhttp.exceptions import ActionFailed
|
||||
from . import Platform
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
@@ -165,7 +166,7 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
await self._reply(message, res)
|
||||
|
||||
async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]):
|
||||
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
|
||||
if isinstance(message_chain, str):
|
||||
message_chain = [Plain(text=message_chain), ]
|
||||
|
||||
@@ -179,7 +180,15 @@ class AIOCQHTTP(Platform):
|
||||
image_idx.append(idx)
|
||||
ret.append(d)
|
||||
try:
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
if isinstance(message, AstrBotMessage):
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
if isinstance(message, dict):
|
||||
if 'group_id' in message:
|
||||
await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
|
||||
elif 'user_id' in message:
|
||||
await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
|
||||
else:
|
||||
raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
|
||||
except ActionFailed as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
@@ -195,4 +204,17 @@ class AIOCQHTTP(Platform):
|
||||
logger.info(f"上传成功。")
|
||||
ret[idx]['data']['file'] = image_url
|
||||
ret[idx]['data']['path'] = image_url
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
await self.bot.send(message.raw_message, ret)
|
||||
|
||||
async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
|
||||
'''
|
||||
以主动的方式给QQ用户、QQ群发送一条消息。
|
||||
|
||||
`target` 接收一个 dict 类型的值引用。
|
||||
|
||||
- 要发给 QQ 下的某个用户,请添加 key `user_id`,值为 int 类型的 qq 号;
|
||||
- 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号;
|
||||
|
||||
'''
|
||||
|
||||
await self._reply(target, result_message.message_chain)
|
||||
@@ -14,6 +14,7 @@ from type.types import Context
|
||||
from . import Platform
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from type.command import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
@@ -199,7 +200,7 @@ class QQGOCQ(Platform):
|
||||
return
|
||||
await self.client.sendGroupMessage(group_id, message_chain)
|
||||
|
||||
async def send_msg(self, target: Dict[str, int], result_message: Union[List[BaseMessageComponent], str]):
|
||||
async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
|
||||
'''
|
||||
以主动的方式给用户、群或者频道发送一条消息。
|
||||
|
||||
@@ -211,7 +212,7 @@ class QQGOCQ(Platform):
|
||||
|
||||
guild_id 不是频道号。
|
||||
'''
|
||||
await self._reply(target, result_message)
|
||||
await self._reply(target, result_message.message_chain)
|
||||
|
||||
def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage:
|
||||
abm = AstrBotMessage()
|
||||
|
||||
@@ -13,6 +13,7 @@ from util.io import save_temp_img, download_image_by_url
|
||||
from . import Platform
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from type.command import *
|
||||
from typing import Union, List, Dict
|
||||
from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
@@ -42,11 +43,16 @@ class botClient(Client):
|
||||
# 转换层
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
await self.platform.handle_msg(abm)
|
||||
|
||||
# 收到 C2C 消息
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
await self.platform.handle_msg(abm)
|
||||
|
||||
|
||||
class QQOfficial(Platform):
|
||||
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
|
||||
super().__init__()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
@@ -80,6 +86,8 @@ class QQOfficial(Platform):
|
||||
|
||||
self.client.set_platform(self)
|
||||
|
||||
self.test_mode = test_mode
|
||||
|
||||
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
|
||||
plain_text = ""
|
||||
image_path = None # only one img supported
|
||||
@@ -107,11 +115,17 @@ class QQOfficial(Platform):
|
||||
abm.tag = "qqchan"
|
||||
msg: List[BaseMessageComponent] = []
|
||||
|
||||
if message_type == MessageType.GROUP_MESSAGE:
|
||||
abm.sender = MessageMember(
|
||||
message.author.member_openid,
|
||||
""
|
||||
)
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
|
||||
if isinstance(message, botpy.message.GroupMessage):
|
||||
abm.sender = MessageMember(
|
||||
message.author.member_openid,
|
||||
""
|
||||
)
|
||||
else:
|
||||
abm.sender = MessageMember(
|
||||
message.author.user_openid,
|
||||
""
|
||||
)
|
||||
abm.message_str = message.content.strip()
|
||||
abm.self_id = "unknown_selfid"
|
||||
|
||||
@@ -126,8 +140,7 @@ class QQOfficial(Platform):
|
||||
msg.append(img)
|
||||
abm.message = msg
|
||||
|
||||
elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE:
|
||||
# 目前对于 FRIEND_MESSAGE 只处理频道私聊
|
||||
elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage):
|
||||
try:
|
||||
abm.self_id = str(message.mentions[0].id)
|
||||
except:
|
||||
@@ -175,7 +188,7 @@ class QQOfficial(Platform):
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
assert isinstance(message.raw_message, (botpy.message.Message,
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage))
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
is_group = message.type != MessageType.FRIEND_MESSAGE
|
||||
|
||||
_t = "/私聊" if not is_group else ""
|
||||
@@ -209,13 +222,15 @@ class QQOfficial(Platform):
|
||||
if not message_result:
|
||||
return
|
||||
|
||||
await self.reply_msg(message, message_result.result_message)
|
||||
ret = await self.reply_msg(message, message_result.result_message)
|
||||
if message_result.callback:
|
||||
message_result.callback()
|
||||
|
||||
# 如果是等待回复的消息
|
||||
if session_id in self.waiting and self.waiting[session_id] == '':
|
||||
self.waiting[session_id] = message
|
||||
|
||||
return ret
|
||||
|
||||
async def reply_msg(self,
|
||||
message: AstrBotMessage,
|
||||
@@ -225,7 +240,7 @@ class QQOfficial(Platform):
|
||||
'''
|
||||
source = message.raw_message
|
||||
assert isinstance(source, (botpy.message.Message,
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage))
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
logger.info(
|
||||
f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}")
|
||||
|
||||
@@ -253,12 +268,14 @@ class QQOfficial(Platform):
|
||||
'message_reference': msg_ref
|
||||
}
|
||||
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if isinstance(message.raw_message, botpy.message.GroupMessage):
|
||||
data['group_openid'] = str(source.group_openid)
|
||||
elif message.type == MessageType.GUILD_MESSAGE:
|
||||
elif isinstance(message.raw_message, botpy.message.Message):
|
||||
data['channel_id'] = source.channel_id
|
||||
elif message.type == MessageType.FRIEND_MESSAGE:
|
||||
elif isinstance(message.raw_message, botpy.message.DirectMessage):
|
||||
data['guild_id'] = source.guild_id
|
||||
elif isinstance(message.raw_message, botpy.message.C2CMessage):
|
||||
data['openid'] = source.author.user_openid
|
||||
if image_path:
|
||||
data['file_image'] = image_path
|
||||
if rendered_images:
|
||||
@@ -269,14 +286,13 @@ class QQOfficial(Platform):
|
||||
_data['message_reference'] = None
|
||||
|
||||
try:
|
||||
await self._reply(**_data)
|
||||
return
|
||||
return await self._reply(**_data)
|
||||
except BaseException as e:
|
||||
logger.warn(traceback.format_exc())
|
||||
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
|
||||
|
||||
try:
|
||||
await self._reply(**data)
|
||||
return await self._reply(**data)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
# 分割过长的消息
|
||||
@@ -286,28 +302,27 @@ class QQOfficial(Platform):
|
||||
split_res.append(plain_text[len(plain_text)//2:])
|
||||
for i in split_res:
|
||||
data['content'] = i
|
||||
await self._reply(**data)
|
||||
return await self._reply(**data)
|
||||
else:
|
||||
try:
|
||||
# 防止被qq频道过滤消息
|
||||
plain_text = plain_text.replace(".", " . ")
|
||||
await self._reply(**data)
|
||||
return await self._reply(**data)
|
||||
except BaseException as e:
|
||||
try:
|
||||
data['content'] = str.join(" ", plain_text)
|
||||
await self._reply(**data)
|
||||
return await self._reply(**data)
|
||||
except BaseException as e:
|
||||
plain_text = re.sub(
|
||||
r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
|
||||
plain_text = plain_text.replace(".", "·")
|
||||
data['content'] = plain_text
|
||||
await self._reply(**data)
|
||||
return await self._reply(**data)
|
||||
|
||||
async def _reply(self, **kwargs):
|
||||
if 'group_openid' in kwargs:
|
||||
if 'group_openid' in kwargs or 'openid' in kwargs:
|
||||
# QQ群组消息
|
||||
# qq群组消息需要自行上传,暂时不处理
|
||||
if 'file_image' in kwargs:
|
||||
if 'file_image' in kwargs and kwargs['file_image']:
|
||||
file_image_path = kwargs['file_image'].replace("file:///", "")
|
||||
if file_image_path:
|
||||
|
||||
@@ -317,48 +332,61 @@ class QQOfficial(Platform):
|
||||
logger.debug(f"上传图片: {file_image_path}")
|
||||
image_url = await self.context.image_uploader.upload_image(file_image_path)
|
||||
logger.debug(f"上传成功: {image_url}")
|
||||
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
|
||||
if 'group_openid' in kwargs:
|
||||
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
|
||||
elif 'openid' in kwargs:
|
||||
media = await self.client.api.post_c2c_file(kwargs['openid'], 1, image_url)
|
||||
del kwargs['file_image']
|
||||
kwargs['media'] = media
|
||||
logger.debug(f"发送群图片: {media}")
|
||||
kwargs['msg_type'] = 7 # 富媒体
|
||||
await self.client.api.post_group_message(**kwargs)
|
||||
if self.test_mode:
|
||||
return kwargs
|
||||
if 'group_openid' in kwargs:
|
||||
await self.client.api.post_group_message(**kwargs)
|
||||
elif 'openid' in kwargs:
|
||||
await self.client.api.post_c2c_message(**kwargs)
|
||||
elif 'channel_id' in kwargs:
|
||||
# 频道消息
|
||||
if 'file_image' in kwargs:
|
||||
if 'file_image' in kwargs and kwargs['file_image']:
|
||||
kwargs['file_image'] = kwargs['file_image'].replace("file:///", "")
|
||||
# 频道消息发图只支持本地
|
||||
if kwargs['file_image'].startswith("http"):
|
||||
kwargs['file_image'] = await download_image_by_url(kwargs['file_image'])
|
||||
if self.test_mode:
|
||||
return kwargs
|
||||
await self.client.api.post_message(**kwargs)
|
||||
else:
|
||||
elif 'guild_id' in kwargs:
|
||||
# 频道私聊消息
|
||||
if 'file_image' in kwargs:
|
||||
if 'file_image' in kwargs and kwargs['file_image']:
|
||||
kwargs['file_image'] = kwargs['file_image'].replace("file:///", "")
|
||||
if kwargs['file_image'].startswith("http"):
|
||||
kwargs['file_image'] = await download_image_by_url(kwargs['file_image'])
|
||||
if self.test_mode:
|
||||
return kwargs
|
||||
await self.client.api.post_dms(**kwargs)
|
||||
else:
|
||||
raise ValueError("Unknown target type.")
|
||||
|
||||
async def send_msg(self, target: Dict[str, str], result_message: Union[List[BaseMessageComponent], str]):
|
||||
async def send_msg(self, target: Dict[str, str], result_message: CommandResult):
|
||||
'''
|
||||
以主动的方式给用户、群或者频道发送一条消息。
|
||||
以主动的方式给频道用户、群、频道或者消息列表用户(QQ用户)发送一条消息。
|
||||
|
||||
`target` 接收一个 dict 类型的值引用。
|
||||
|
||||
- 如果目标是 QQ 群,请添加 key `group_openid`。
|
||||
- 如果目标是 频道消息,请添加 key `channel_id`。
|
||||
- 如果目标是 频道私聊,请添加 key `guild_id`。
|
||||
- 如果目标是 QQ 用户,请添加 key `openid`。
|
||||
'''
|
||||
if isinstance(result_message, list):
|
||||
plain_text, image_path = await self._parse_to_qqofficial(result_message)
|
||||
else:
|
||||
plain_text = result_message
|
||||
plain_text, image_path = await self._parse_to_qqofficial(result_message.message_chain)
|
||||
|
||||
payload = {
|
||||
'content': plain_text,
|
||||
'file_image': image_path,
|
||||
**target
|
||||
}
|
||||
if image_path:
|
||||
payload['file_image'] = image_path
|
||||
await self._reply(**payload)
|
||||
|
||||
def wait_for_message(self, channel_id: int) -> AstrBotMessage:
|
||||
|
||||
@@ -24,6 +24,7 @@ class CommandResult():
|
||||
self.success = success
|
||||
self.message_chain = message_chain
|
||||
self.command_name = command_name
|
||||
self.is_use_t2i = None # default
|
||||
|
||||
def message(self, message: str):
|
||||
'''
|
||||
@@ -61,6 +62,15 @@ class CommandResult():
|
||||
'''
|
||||
self.message_chain = [Image.fromFileSystem(path), ]
|
||||
return self
|
||||
|
||||
# def use_t2i(self, use_t2i: bool):
|
||||
# '''
|
||||
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
|
||||
|
||||
# CommandResult().use_t2i(False)
|
||||
# '''
|
||||
# self.is_use_t2i = use_t2i
|
||||
# return self
|
||||
|
||||
def _result_tuple(self):
|
||||
return (self.success, self.message_chain, self.command_name)
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
VERSION = '3.3.0'
|
||||
VERSION = '3.3.4'
|
||||
@@ -43,13 +43,10 @@ class AstrMessageEvent():
|
||||
context,
|
||||
session_id)
|
||||
return ame
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageResult():
|
||||
result_message: Union[str, list]
|
||||
is_command_call: Optional[bool] = False
|
||||
use_t2i: Optional[bool] = None # None 为跟随用户设置
|
||||
callback: Optional[callable] = None
|
||||
|
||||
@@ -9,6 +9,7 @@ from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.image_uploader import ImageUploader
|
||||
from util.updator.plugin_updator import PluginUpdator
|
||||
from model.plugin.command import PluginCommandBridge
|
||||
from model.provider.provider import Provider
|
||||
|
||||
|
||||
class Context:
|
||||
@@ -68,6 +69,14 @@ class Context:
|
||||
'''
|
||||
task = asyncio.create_task(coro, name=task_name)
|
||||
self.ext_tasks.append(task)
|
||||
|
||||
def register_provider(self, llm_name: str, provider: Provider, origin: str = ''):
|
||||
'''
|
||||
注册一个提供 LLM 资源的 Provider。
|
||||
|
||||
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
|
||||
'''
|
||||
self.llms.append(RegisteredLLM(llm_name, provider, origin))
|
||||
|
||||
def find_platform(self, platform_name: str) -> RegisteredPlatform:
|
||||
for platform in self.platforms:
|
||||
|
||||
@@ -122,7 +122,7 @@ class FuncCall():
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = self.provider.text_chat(prompt, session_id)
|
||||
res = self.provider.text_chat(prompt=prompt, session_id=session_id)
|
||||
if res.find('```') != -1:
|
||||
res = res[res.find('```json') + 7: res.rfind('```')]
|
||||
gu.log("REVGPT func_call json result",
|
||||
@@ -187,7 +187,7 @@ class FuncCall():
|
||||
_c = 0
|
||||
while _c < 5:
|
||||
try:
|
||||
res = self.provider.text_chat(after_prompt, session_id)
|
||||
res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
|
||||
# 截取```之间的内容
|
||||
gu.log(
|
||||
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
|
||||
|
||||
@@ -127,7 +127,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
|
||||
function_invoked_ret = ""
|
||||
if official_fc:
|
||||
# we use official function-calling
|
||||
result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func())
|
||||
result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
|
||||
if isinstance(result, Function):
|
||||
logger.debug(f"web_searcher - function-calling: {result}")
|
||||
func_obj = None
|
||||
@@ -136,14 +136,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
|
||||
func_obj = i["func_obj"]
|
||||
break
|
||||
if not func_obj:
|
||||
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
try:
|
||||
args = json.loads(result.arguments)
|
||||
function_invoked_ret = await func_obj(**args)
|
||||
has_func = True
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)"
|
||||
return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
|
||||
else:
|
||||
return result
|
||||
else:
|
||||
@@ -162,7 +162,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
|
||||
has_func = True
|
||||
|
||||
if has_func:
|
||||
await provider.forget(session_id)
|
||||
await provider.forget(session_id=session_id, )
|
||||
summary_prompt = f"""
|
||||
你是一个专业且高效的助手,你的任务是
|
||||
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
|
||||
@@ -178,6 +178,6 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
|
||||
|
||||
# 相关材料
|
||||
{function_invoked_ret}"""
|
||||
ret = await provider.text_chat(summary_prompt, session_id)
|
||||
ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
|
||||
return ret
|
||||
return function_invoked_ret
|
||||
|
||||
@@ -8,4 +8,5 @@ Platform类是消息平台的抽象类,定义了消息平台的基本接口。
|
||||
from model.platform import Platform
|
||||
|
||||
from model.platform.qq_nakuru import QQGOCQ
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
@@ -32,7 +32,7 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
"options": {
|
||||
"full_page": True,
|
||||
"type": "jpeg",
|
||||
"quality": 25,
|
||||
"quality": 40,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from type.config import VERSION
|
||||
from util.io import on_error
|
||||
from util.io import on_error, download_file
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -58,7 +58,8 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
raise Exception(f"未找到版本号为 {version} 的更新文件。")
|
||||
|
||||
try:
|
||||
self.download_from_repo_url("temp", data['zipball_url'])
|
||||
# self.download_from_repo_url("temp", file_url)
|
||||
download_file(file_url, "temp.zip")
|
||||
self.unzip_file("temp.zip", self.MAIN_PATH)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user