Compare commits

...

11 Commits

Author SHA1 Message Date
Soulter b241b0f954 update version 2024-07-27 12:31:15 -04:00
Soulter 171dd1dc02 feat: qq 官方机器人接口支持C2C 2024-07-27 12:30:09 -04:00
Soulter af62d969d7 perf: 更改 send_msg 接口 2024-07-27 11:26:02 -04:00
Soulter c4fd9a66c6 update version to 3.3.3 2024-07-27 11:08:51 -04:00
Soulter d191997a39 feat: aiocqhttp 适配器适配主动发送消息接口 2024-07-27 11:07:26 -04:00
Soulter 853ac4c104 fix: 优化 update 提示 2024-07-27 04:58:15 -04:00
Soulter ed053acad6 update: version 2024-07-27 04:47:57 -04:00
Soulter f147634e51 fix: 修复update异常 2024-07-27 04:43:53 -04:00
Soulter e3b2a68341 Merge pull request #179 from Soulter/refactor-v3.3.0
feat: 新增 Provider 注册接口;新增 provider 指令
2024-07-27 16:31:03 +08:00
Soulter 84c450aef9 feat: 新增 Provider 注册接口;新增 provider 指令 2024-07-27 04:25:27 -04:00
Soulter f52a0eb43a fix: 修复默认配置迁移问题 2024-07-27 08:58:26 +08:00
21 changed files with 214 additions and 202 deletions
+1
View File
@@ -10,3 +10,4 @@ cmd_config.json
data/* data/*
cookies.json cookies.json
logs/ logs/
addons/plugins
+5 -1
View File
@@ -101,10 +101,14 @@ class AstrBotBootstrap():
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
self.llm_instance = ProviderOpenAIOfficial(self.context) self.llm_instance = ProviderOpenAIOfficial(self.context)
self.openai_command_handler.set_provider(self.llm_instance) self.openai_command_handler.set_provider(self.llm_instance)
self.context.register_provider("internal_openai", self.llm_instance)
logger.info("已启用 OpenAI API 支持。") logger.info("已启用 OpenAI API 支持。")
def load_plugins(self): def load_plugins(self):
self.plugin_manager.plugin_reload() self.plugin_manager.plugin_reload()
def load_platform(self): def load_platform(self):
return self.platfrom_manager.load_platforms() platforms = self.platfrom_manager.load_platforms()
if not platforms:
logger.warn("未启用任何消息平台。")
return platforms
+8 -2
View File
@@ -115,6 +115,9 @@ class MessageHandler():
self.nicks = self.context.nick self.nicks = self.context.nick
self.provider = provider self.provider = provider
self.reply_prefix = self.context.reply_prefix self.reply_prefix = self.context.reply_prefix
def set_provider(self, provider: Provider):
self.provider = provider
async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult: async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult:
''' '''
@@ -148,7 +151,8 @@ class MessageHandler():
assert(isinstance(cmd_res, CommandResult)) assert(isinstance(cmd_res, CommandResult))
return MessageResult( return MessageResult(
cmd_res.message_chain, cmd_res.message_chain,
is_command_call=True is_command_call=True,
use_t2i=cmd_res.is_use_t2i
) )
# check if the message is a llm-wake-up command # check if the message is a llm-wake-up command
@@ -178,7 +182,9 @@ class MessageHandler():
llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider) llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider)
else: else:
llm_result = await provider.text_chat( llm_result = await provider.text_chat(
msg_plain, message.session_id, image_url prompt=msg_plain,
session_id=message.session_id,
image_url=image_url
) )
except BaseException as e: except BaseException as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
-137
View File
@@ -1,137 +0,0 @@
# 如果你不知道怎么部署,请查看https://soulter.top/posts/qpdg.html
# 不一定需要key了,如果你没有key但有openAI账号或者必应账号,可以考虑使用下面的逆向库
###############平台设置#################
# QQ频道机器人
# QQ开放平台的appid和令牌
# q.qq.com
# enable为true则启用,false则不启用
qqbot:
enable: true
appid:
token:
# QQ机器人
# enable为true则启用,false则不启用
# 需要安装GO-CQHTTP配合使用。
# 文档:https://docs.go-cqhttp.org/
# 请将go-cqhttp的配置文件的sever部分粘贴为以下内容,否则无法使用
# 请先启动go-cqhttp再启动本程序
#
# servers:
# - http:
# host: 127.0.0.1
# version: 0
# port: 5700
# timeout: 5
# - ws:
# address: 127.0.0.1:6700
# middlewares:
# <<: *default
gocqbot:
enable: false
# 设置是否一个人一个会话
uniqueSessionMode: false
# QChannelBot 的版本,请勿修改此字段,否则可能产生一些bug
version: 3.0
# [Beta] 转储历史记录时间间隔(分钟)
dump_history_interval: 10
# 一个用户只能在time秒内发送count条消息
limit:
time: 60
count: 5
# 公告
notice: "此机器人由Github项目QQChannelChatGPT驱动。"
# 是否打开私信功能
# 设置为true则频道成员可以私聊机器人。
# 设置为false则频道成员不能私聊机器人。
direct_message_mode: true
# 系统代理
# http_proxy: http://localhost:7890
# https_proxy: http://localhost:7890
# 自定义回复前缀,如[Rev]或其他,务必加引号以防止不必要的bug。
reply_prefix:
openai_official: "[GPT]"
rev_chatgpt: "[Rev]"
rev_edgegpt: "[RevBing]"
# 百度内容审核服务
# 新用户免费5万次调用。https://cloud.baidu.com/doc/ANTIPORN/index.html
baidu_aip:
enable: false
app_id:
api_key:
secret_key:
###############语言模型设置#################
# OpenAI官方API
# 注意:已支持多key自动切换,方法:
# key:
# - sk-xxxxxx
# - sk-xxxxxx
# 在下方非注释的地方使用以上格式
# 关于api_base:可以使用一些云函数(如腾讯、阿里)来避免国内被墙的问题。
# 详见:
# https://github.com/Ice-Hazymoon/openai-scf-proxy
# https://github.com/Soulter/QQChannelChatGPT/issues/42
# 设置为none则表示使用官方默认api地址
openai:
key:
-
api_base: none
# 这里是GPT配置,语言模型默认使用gpt-3.5-turbo
chatGPTConfigs:
model: gpt-3.5-turbo
max_tokens: 3000
temperature: 0.9
top_p: 1
frequency_penalty: 0
presence_penalty: 0
total_tokens_limit: 5000
# 逆向文心一言【暂时不可用,请勿使用】
rev_ernie:
enable: false
# 逆向New Bing
# 需要在项目根目录下创建cookies.json并粘贴cookies进去。
# 详见:https://soulter.top/posts/qpdg.html
rev_edgegpt:
enable: false
# 逆向ChatGPT库
# https://github.com/acheong08/ChatGPT
# 优点:免费(无免费额度限制);
# 缺点:速度相对慢。OpenAI 速率限制:免费帐户每小时 50 个请求。您可以通过多帐户循环来绕过它
# enable设置为true后,将会停止使用上面正常的官方API调用而使用本逆向项目
#
# 多账户可以保证每个请求都能得到及时的回复。
# 关于account的格式
# account:
# - email: 第1个账户
# password: 第1个账户密码
# - email: 第2个账户
# password: 第2个账户密码
# - ....
# 支持使用access_token登录
# 例:
# - session_token: xxxxx
# - access_token: xxxx
# 请严格按照上面这个格式填写。
# 逆向ChatGPT库的email-password登录方式不工作,建议使用access_token登录
# 获取access_token的方法,详见:https://soulter.top/posts/qpdg.html
rev_ChatGPT:
enable: false
account:
- access_token:
+1 -1
View File
@@ -285,7 +285,7 @@ class AstrBotDashBoard():
ret = self.astrbot_updator.check_update(None, None) ret = self.astrbot_updator.check_update(None, None)
return Response( return Response(
status="success", status="success",
message=str(ret), message=str(ret) if ret is not None else "已经是最新版本了。",
data={ data={
"has_new_version": ret is not None "has_new_version": ret is not None
} }
+30 -1
View File
@@ -26,7 +26,36 @@ class InternalCommandHandler:
self.manager.register("websearch", "网页搜索开关", 10, self.web_search) self.manager.register("websearch", "网页搜索开关", 10, self.web_search)
self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle) self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle)
self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid) self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid)
self.manager.register("provider", "查看和切换当前使用的 LLM 资源来源", 10, self.provider)
def provider(self, message: AstrMessageEvent, context: Context):
if len(context.llms) == 0:
return CommandResult().message("当前没有加载任何 LLM 资源。")
tokens = self.manager.command_parser.parse(message.message_str)
if tokens.len == 1:
ret = "## 当前载入的 LLM 资源\n"
for idx, llm in enumerate(context.llms):
ret += f"{idx}. {llm.llm_name}"
if llm.origin:
ret += f" (来源: {llm.origin})"
if context.message_handler.provider == llm.llm_instance:
ret += " (当前使用)"
ret += "\n"
ret += "\n使用 provider <序号> 切换 LLM 资源。"
return CommandResult().message(ret)
else:
try:
idx = int(tokens.get(1))
if idx >= len(context.llms):
return CommandResult().message("provider: 无效的序号。")
context.message_handler.set_provider(context.llms[idx].llm_instance)
return CommandResult().message(f"已经成功切换到 LLM 资源 {context.llms[idx].llm_name}")
except BaseException as e:
return CommandResult().message("provider: 参数错误。")
def set_nick(self, message: AstrMessageEvent, context: Context): def set_nick(self, message: AstrMessageEvent, context: Context):
message_str = message.message_str message_str = message.message_str
if message.role != "admin": if message.role != "admin":
+2 -1
View File
@@ -73,7 +73,8 @@ class CommandManager():
if message_str.startswith(command): if message_str.startswith(command):
logger.info(f"触发 {command} 指令。") logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context) command_result = await self.execute_handler(command, message_event, context)
return command_result if command_result.hit:
return command_result
async def execute_handler(self, async def execute_handler(self,
command: str, command: str,
+2 -1
View File
@@ -2,6 +2,7 @@ import abc
from typing import Union, Any, List from typing import Union, Any, List
from nakuru.entities.components import Plain, At, Image, BaseMessageComponent from nakuru.entities.components import Plain, At, Image, BaseMessageComponent
from type.astrbot_message import AstrBotMessage from type.astrbot_message import AstrBotMessage
from type.command import CommandResult
class Platform(): class Platform():
@@ -24,7 +25,7 @@ class Platform():
pass pass
@abc.abstractmethod @abc.abstractmethod
async def send_msg(self, target: Any, result_message: Union[List[BaseMessageComponent], str]): async def send_msg(self, target: Any, result_message: CommandResult):
''' '''
发送消息(主动) 发送消息(主动)
''' '''
+25 -3
View File
@@ -7,6 +7,7 @@ from aiocqhttp.exceptions import ActionFailed
from . import Platform from . import Platform
from type.astrbot_message import * from type.astrbot_message import *
from type.message_event import * from type.message_event import *
from type.command import *
from typing import Union, List, Dict from typing import Union, List, Dict
from nakuru.entities.components import * from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
@@ -165,7 +166,7 @@ class AIOCQHTTP(Platform):
await self._reply(message, res) await self._reply(message, res)
async def _reply(self, message: AstrBotMessage, message_chain: List[BaseMessageComponent]): async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
if isinstance(message_chain, str): if isinstance(message_chain, str):
message_chain = [Plain(text=message_chain), ] message_chain = [Plain(text=message_chain), ]
@@ -179,7 +180,15 @@ class AIOCQHTTP(Platform):
image_idx.append(idx) image_idx.append(idx)
ret.append(d) ret.append(d)
try: try:
await self.bot.send(message.raw_message, ret) if isinstance(message, AstrBotMessage):
await self.bot.send(message.raw_message, ret)
if isinstance(message, dict):
if 'group_id' in message:
await self.bot.send_group_msg(group_id=message['group_id'], message=ret)
elif 'user_id' in message:
await self.bot.send_private_msg(user_id=message['user_id'], message=ret)
else:
raise Exception("aiocqhttp: 无法识别的消息来源。仅支持 group_id 和 user_id。")
except ActionFailed as e: except ActionFailed as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(f"回复消息失败: {e}") logger.error(f"回复消息失败: {e}")
@@ -195,4 +204,17 @@ class AIOCQHTTP(Platform):
logger.info(f"上传成功。") logger.info(f"上传成功。")
ret[idx]['data']['file'] = image_url ret[idx]['data']['file'] = image_url
ret[idx]['data']['path'] = image_url ret[idx]['data']['path'] = image_url
await self.bot.send(message.raw_message, ret) await self.bot.send(message.raw_message, ret)
async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
'''
以主动的方式给QQ用户、QQ群发送一条消息。
`target` 接收一个 dict 类型的值引用。
- 要发给 QQ 下的某个用户,请添加 key `user_id`,值为 int 类型的 qq 号;
- 要发给某个群聊,请添加 key `group_id`,值为 int 类型的 qq 群号;
'''
await self._reply(target, result_message.message_chain)
+3 -2
View File
@@ -14,6 +14,7 @@ from type.types import Context
from . import Platform from . import Platform
from type.astrbot_message import * from type.astrbot_message import *
from type.message_event import * from type.message_event import *
from type.command import *
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
from logging import Logger from logging import Logger
from astrbot.message.handler import MessageHandler from astrbot.message.handler import MessageHandler
@@ -199,7 +200,7 @@ class QQGOCQ(Platform):
return return
await self.client.sendGroupMessage(group_id, message_chain) await self.client.sendGroupMessage(group_id, message_chain)
async def send_msg(self, target: Dict[str, int], result_message: Union[List[BaseMessageComponent], str]): async def send_msg(self, target: Dict[str, int], result_message: CommandResult):
''' '''
以主动的方式给用户、群或者频道发送一条消息。 以主动的方式给用户、群或者频道发送一条消息。
@@ -211,7 +212,7 @@ class QQGOCQ(Platform):
guild_id 不是频道号。 guild_id 不是频道号。
''' '''
await self._reply(target, result_message) await self._reply(target, result_message.message_chain)
def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage: def convert_message(self, message: Union[GroupMessage, FriendMessage, GuildMessage]) -> AstrBotMessage:
abm = AstrBotMessage() abm = AstrBotMessage()
+64 -36
View File
@@ -13,6 +13,7 @@ from util.io import save_temp_img, download_image_by_url
from . import Platform from . import Platform
from type.astrbot_message import * from type.astrbot_message import *
from type.message_event import * from type.message_event import *
from type.command import *
from typing import Union, List, Dict from typing import Union, List, Dict
from nakuru.entities.components import * from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
@@ -42,11 +43,16 @@ class botClient(Client):
# 转换层 # 转换层
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE) abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
await self.platform.handle_msg(abm) await self.platform.handle_msg(abm)
# 收到 C2C 消息
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
await self.platform.handle_msg(abm)
class QQOfficial(Platform): class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None: def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
super().__init__() super().__init__()
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
@@ -80,6 +86,8 @@ class QQOfficial(Platform):
self.client.set_platform(self) self.client.set_platform(self)
self.test_mode = test_mode
async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False): async def _parse_to_qqofficial(self, message: List[BaseMessageComponent], is_group: bool = False):
plain_text = "" plain_text = ""
image_path = None # only one img supported image_path = None # only one img supported
@@ -107,11 +115,17 @@ class QQOfficial(Platform):
abm.tag = "qqchan" abm.tag = "qqchan"
msg: List[BaseMessageComponent] = [] msg: List[BaseMessageComponent] = []
if message_type == MessageType.GROUP_MESSAGE: if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
abm.sender = MessageMember( if isinstance(message, botpy.message.GroupMessage):
message.author.member_openid, abm.sender = MessageMember(
"" message.author.member_openid,
) ""
)
else:
abm.sender = MessageMember(
message.author.user_openid,
""
)
abm.message_str = message.content.strip() abm.message_str = message.content.strip()
abm.self_id = "unknown_selfid" abm.self_id = "unknown_selfid"
@@ -126,8 +140,7 @@ class QQOfficial(Platform):
msg.append(img) msg.append(img)
abm.message = msg abm.message = msg
elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE: elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage):
# 目前对于 FRIEND_MESSAGE 只处理频道私聊
try: try:
abm.self_id = str(message.mentions[0].id) abm.self_id = str(message.mentions[0].id)
except: except:
@@ -175,7 +188,7 @@ class QQOfficial(Platform):
async def handle_msg(self, message: AstrBotMessage): async def handle_msg(self, message: AstrBotMessage):
assert isinstance(message.raw_message, (botpy.message.Message, assert isinstance(message.raw_message, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage)) botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
is_group = message.type != MessageType.FRIEND_MESSAGE is_group = message.type != MessageType.FRIEND_MESSAGE
_t = "/私聊" if not is_group else "" _t = "/私聊" if not is_group else ""
@@ -209,13 +222,15 @@ class QQOfficial(Platform):
if not message_result: if not message_result:
return return
await self.reply_msg(message, message_result.result_message) ret = await self.reply_msg(message, message_result.result_message)
if message_result.callback: if message_result.callback:
message_result.callback() message_result.callback()
# 如果是等待回复的消息 # 如果是等待回复的消息
if session_id in self.waiting and self.waiting[session_id] == '': if session_id in self.waiting and self.waiting[session_id] == '':
self.waiting[session_id] = message self.waiting[session_id] = message
return ret
async def reply_msg(self, async def reply_msg(self,
message: AstrBotMessage, message: AstrBotMessage,
@@ -225,7 +240,7 @@ class QQOfficial(Platform):
''' '''
source = message.raw_message source = message.raw_message
assert isinstance(source, (botpy.message.Message, assert isinstance(source, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage)) botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
logger.info( logger.info(
f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}") f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}")
@@ -253,12 +268,14 @@ class QQOfficial(Platform):
'message_reference': msg_ref 'message_reference': msg_ref
} }
if message.type == MessageType.GROUP_MESSAGE: if isinstance(message.raw_message, botpy.message.GroupMessage):
data['group_openid'] = str(source.group_openid) data['group_openid'] = str(source.group_openid)
elif message.type == MessageType.GUILD_MESSAGE: elif isinstance(message.raw_message, botpy.message.Message):
data['channel_id'] = source.channel_id data['channel_id'] = source.channel_id
elif message.type == MessageType.FRIEND_MESSAGE: elif isinstance(message.raw_message, botpy.message.DirectMessage):
data['guild_id'] = source.guild_id data['guild_id'] = source.guild_id
elif isinstance(message.raw_message, botpy.message.C2CMessage):
data['openid'] = source.author.user_openid
if image_path: if image_path:
data['file_image'] = image_path data['file_image'] = image_path
if rendered_images: if rendered_images:
@@ -269,14 +286,13 @@ class QQOfficial(Platform):
_data['message_reference'] = None _data['message_reference'] = None
try: try:
await self._reply(**_data) return await self._reply(**_data)
return
except BaseException as e: except BaseException as e:
logger.warn(traceback.format_exc()) logger.warn(traceback.format_exc())
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
try: try:
await self._reply(**data) return await self._reply(**data)
except BaseException as e: except BaseException as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
# 分割过长的消息 # 分割过长的消息
@@ -286,28 +302,27 @@ class QQOfficial(Platform):
split_res.append(plain_text[len(plain_text)//2:]) split_res.append(plain_text[len(plain_text)//2:])
for i in split_res: for i in split_res:
data['content'] = i data['content'] = i
await self._reply(**data) return await self._reply(**data)
else: else:
try: try:
# 防止被qq频道过滤消息 # 防止被qq频道过滤消息
plain_text = plain_text.replace(".", " . ") plain_text = plain_text.replace(".", " . ")
await self._reply(**data) return await self._reply(**data)
except BaseException as e: except BaseException as e:
try: try:
data['content'] = str.join(" ", plain_text) data['content'] = str.join(" ", plain_text)
await self._reply(**data) return await self._reply(**data)
except BaseException as e: except BaseException as e:
plain_text = re.sub( plain_text = re.sub(
r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE) r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '[被隐藏的链接]', str(e), flags=re.MULTILINE)
plain_text = plain_text.replace(".", "·") plain_text = plain_text.replace(".", "·")
data['content'] = plain_text data['content'] = plain_text
await self._reply(**data) return await self._reply(**data)
async def _reply(self, **kwargs): async def _reply(self, **kwargs):
if 'group_openid' in kwargs: if 'group_openid' in kwargs or 'openid' in kwargs:
# QQ群组消息 # QQ群组消息
# qq群组消息需要自行上传,暂时不处理 if 'file_image' in kwargs and kwargs['file_image']:
if 'file_image' in kwargs:
file_image_path = kwargs['file_image'].replace("file:///", "") file_image_path = kwargs['file_image'].replace("file:///", "")
if file_image_path: if file_image_path:
@@ -317,48 +332,61 @@ class QQOfficial(Platform):
logger.debug(f"上传图片: {file_image_path}") logger.debug(f"上传图片: {file_image_path}")
image_url = await self.context.image_uploader.upload_image(file_image_path) image_url = await self.context.image_uploader.upload_image(file_image_path)
logger.debug(f"上传成功: {image_url}") logger.debug(f"上传成功: {image_url}")
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url) if 'group_openid' in kwargs:
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
elif 'openid' in kwargs:
media = await self.client.api.post_c2c_file(kwargs['openid'], 1, image_url)
del kwargs['file_image'] del kwargs['file_image']
kwargs['media'] = media kwargs['media'] = media
logger.debug(f"发送群图片: {media}") logger.debug(f"发送群图片: {media}")
kwargs['msg_type'] = 7 # 富媒体 kwargs['msg_type'] = 7 # 富媒体
await self.client.api.post_group_message(**kwargs) if self.test_mode:
return kwargs
if 'group_openid' in kwargs:
await self.client.api.post_group_message(**kwargs)
elif 'openid' in kwargs:
await self.client.api.post_c2c_message(**kwargs)
elif 'channel_id' in kwargs: elif 'channel_id' in kwargs:
# 频道消息 # 频道消息
if 'file_image' in kwargs: if 'file_image' in kwargs and kwargs['file_image']:
kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") kwargs['file_image'] = kwargs['file_image'].replace("file:///", "")
# 频道消息发图只支持本地 # 频道消息发图只支持本地
if kwargs['file_image'].startswith("http"): if kwargs['file_image'].startswith("http"):
kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) kwargs['file_image'] = await download_image_by_url(kwargs['file_image'])
if self.test_mode:
return kwargs
await self.client.api.post_message(**kwargs) await self.client.api.post_message(**kwargs)
else: elif 'guild_id' in kwargs:
# 频道私聊消息 # 频道私聊消息
if 'file_image' in kwargs: if 'file_image' in kwargs and kwargs['file_image']:
kwargs['file_image'] = kwargs['file_image'].replace("file:///", "") kwargs['file_image'] = kwargs['file_image'].replace("file:///", "")
if kwargs['file_image'].startswith("http"): if kwargs['file_image'].startswith("http"):
kwargs['file_image'] = await download_image_by_url(kwargs['file_image']) kwargs['file_image'] = await download_image_by_url(kwargs['file_image'])
if self.test_mode:
return kwargs
await self.client.api.post_dms(**kwargs) await self.client.api.post_dms(**kwargs)
else:
raise ValueError("Unknown target type.")
async def send_msg(self, target: Dict[str, str], result_message: Union[List[BaseMessageComponent], str]): async def send_msg(self, target: Dict[str, str], result_message: CommandResult):
''' '''
以主动的方式给用户、群或者频道发送一条消息。 以主动的方式给频道用户、群、频道或者消息列表用户(QQ用户)发送一条消息。
`target` 接收一个 dict 类型的值引用。 `target` 接收一个 dict 类型的值引用。
- 如果目标是 QQ 群,请添加 key `group_openid`。 - 如果目标是 QQ 群,请添加 key `group_openid`。
- 如果目标是 频道消息,请添加 key `channel_id`。 - 如果目标是 频道消息,请添加 key `channel_id`。
- 如果目标是 频道私聊,请添加 key `guild_id`。 - 如果目标是 频道私聊,请添加 key `guild_id`。
- 如果目标是 QQ 用户,请添加 key `openid`。
''' '''
if isinstance(result_message, list): plain_text, image_path = await self._parse_to_qqofficial(result_message.message_chain)
plain_text, image_path = await self._parse_to_qqofficial(result_message)
else:
plain_text = result_message
payload = { payload = {
'content': plain_text, 'content': plain_text,
'file_image': image_path,
**target **target
} }
if image_path:
payload['file_image'] = image_path
await self._reply(**payload) await self._reply(**payload)
def wait_for_message(self, channel_id: int) -> AstrBotMessage: def wait_for_message(self, channel_id: int) -> AstrBotMessage:
+10
View File
@@ -24,6 +24,7 @@ class CommandResult():
self.success = success self.success = success
self.message_chain = message_chain self.message_chain = message_chain
self.command_name = command_name self.command_name = command_name
self.is_use_t2i = None # default
def message(self, message: str): def message(self, message: str):
''' '''
@@ -61,6 +62,15 @@ class CommandResult():
''' '''
self.message_chain = [Image.fromFileSystem(path), ] self.message_chain = [Image.fromFileSystem(path), ]
return self return self
# def use_t2i(self, use_t2i: bool):
# '''
# 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。
# CommandResult().use_t2i(False)
# '''
# self.is_use_t2i = use_t2i
# return self
def _result_tuple(self): def _result_tuple(self):
return (self.success, self.message_chain, self.command_name) return (self.success, self.message_chain, self.command_name)
+1 -1
View File
@@ -1 +1 @@
VERSION = '3.3.0' VERSION = '3.3.4'
+2 -5
View File
@@ -43,13 +43,10 @@ class AstrMessageEvent():
context, context,
session_id) session_id)
return ame return ame
@dataclass @dataclass
class MessageResult(): class MessageResult():
result_message: Union[str, list] result_message: Union[str, list]
is_command_call: Optional[bool] = False is_command_call: Optional[bool] = False
use_t2i: Optional[bool] = None # None 为跟随用户设置
callback: Optional[callable] = None callback: Optional[callable] = None
+9
View File
@@ -9,6 +9,7 @@ from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader from util.image_uploader import ImageUploader
from util.updator.plugin_updator import PluginUpdator from util.updator.plugin_updator import PluginUpdator
from model.plugin.command import PluginCommandBridge from model.plugin.command import PluginCommandBridge
from model.provider.provider import Provider
class Context: class Context:
@@ -68,6 +69,14 @@ class Context:
''' '''
task = asyncio.create_task(coro, name=task_name) task = asyncio.create_task(coro, name=task_name)
self.ext_tasks.append(task) self.ext_tasks.append(task)
def register_provider(self, llm_name: str, provider: Provider, origin: str = ''):
'''
注册一个提供 LLM 资源的 Provider。
`provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。
'''
self.llms.append(RegisteredLLM(llm_name, provider, origin))
def find_platform(self, platform_name: str) -> RegisteredPlatform: def find_platform(self, platform_name: str) -> RegisteredPlatform:
for platform in self.platforms: for platform in self.platforms:
+2 -2
View File
@@ -122,7 +122,7 @@ class FuncCall():
_c = 0 _c = 0
while _c < 3: while _c < 3:
try: try:
res = self.provider.text_chat(prompt, session_id) res = self.provider.text_chat(prompt=prompt, session_id=session_id)
if res.find('```') != -1: if res.find('```') != -1:
res = res[res.find('```json') + 7: res.rfind('```')] res = res[res.find('```json') + 7: res.rfind('```')]
gu.log("REVGPT func_call json result", gu.log("REVGPT func_call json result",
@@ -187,7 +187,7 @@ class FuncCall():
_c = 0 _c = 0
while _c < 5: while _c < 5:
try: try:
res = self.provider.text_chat(after_prompt, session_id) res = self.provider.text_chat(prompt=after_prompt, session_id=session_id)
# 截取```之间的内容 # 截取```之间的内容
gu.log( gu.log(
"DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) "DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"])
+5 -5
View File
@@ -127,7 +127,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
function_invoked_ret = "" function_invoked_ret = ""
if official_fc: if official_fc:
# we use official function-calling # we use official function-calling
result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func()) result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func())
if isinstance(result, Function): if isinstance(result, Function):
logger.debug(f"web_searcher - function-calling: {result}") logger.debug(f"web_searcher - function-calling: {result}")
func_obj = None func_obj = None
@@ -136,14 +136,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
func_obj = i["func_obj"] func_obj = i["func_obj"]
break break
if not func_obj: if not func_obj:
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
try: try:
args = json.loads(result.arguments) args = json.loads(result.arguments)
function_invoked_ret = await func_obj(**args) function_invoked_ret = await func_obj(**args)
has_func = True has_func = True
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)"
else: else:
return result return result
else: else:
@@ -162,7 +162,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
has_func = True has_func = True
if has_func: if has_func:
await provider.forget(session_id) await provider.forget(session_id=session_id, )
summary_prompt = f""" summary_prompt = f"""
你是一个专业且高效的助手,你的任务是 你是一个专业且高效的助手,你的任务是
1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; 1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结;
@@ -178,6 +178,6 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False):
# 相关材料 # 相关材料
{function_invoked_ret}""" {function_invoked_ret}"""
ret = await provider.text_chat(summary_prompt, session_id) ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id)
return ret return ret
return function_invoked_ret return function_invoked_ret
+38
View File
@@ -9,6 +9,44 @@ def init_configs():
''' '''
cc = CmdConfig() cc = CmdConfig()
cc.init_attributes("qqbot", {
"enable": False,
"appid": "",
"token": "",
})
cc.init_attributes("gocqbot", {
"enable": False,
})
cc.init_attributes("uniqueSessionMode", False)
cc.init_attributes("dump_history_interval", 10)
cc.init_attributes("limit", {
"time": 60,
"count": 30,
})
cc.init_attributes("notice", "")
cc.init_attributes("direct_message_mode", True)
cc.init_attributes("reply_prefix", "")
cc.init_attributes("baidu_aip", {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": ""
})
cc.init_attributes("openai", {
"key": [],
"api_base": "",
"chatGPTConfigs": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"total_tokens_limit": 10000,
})
cc.init_attributes("qq_forward_threshold", 200) cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes("qq_welcome", "") cc.init_attributes("qq_welcome", "")
cc.init_attributes("qq_pic_mode", True) cc.init_attributes("qq_pic_mode", True)
+2 -1
View File
@@ -8,4 +8,5 @@ Platform类是消息平台的抽象类,定义了消息平台的基本接口。
from model.platform import Platform from model.platform import Platform
from model.platform.qq_nakuru import QQGOCQ from model.platform.qq_nakuru import QQGOCQ
from model.platform.qq_official import QQOfficial from model.platform.qq_official import QQOfficial
from model.platform.qq_aiocqhttp import AIOCQHTTP
+1 -1
View File
@@ -32,7 +32,7 @@ class NetworkRenderStrategy(RenderStrategy):
"options": { "options": {
"full_page": True, "full_page": True,
"type": "jpeg", "type": "jpeg",
"quality": 25, "quality": 40,
} }
} }
+3 -2
View File
@@ -3,7 +3,7 @@ from util.updator.zip_updator import ReleaseInfo, RepoZipUpdator
from SparkleLogging.utils.core import LogManager from SparkleLogging.utils.core import LogManager
from logging import Logger from logging import Logger
from type.config import VERSION from type.config import VERSION
from util.io import on_error from util.io import on_error, download_file
logger: Logger = LogManager.GetLogger(log_name='astrbot') logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -58,7 +58,8 @@ class AstrBotUpdator(RepoZipUpdator):
raise Exception(f"未找到版本号为 {version} 的更新文件。") raise Exception(f"未找到版本号为 {version} 的更新文件。")
try: try:
self.download_from_repo_url("temp", data['zipball_url']) # self.download_from_repo_url("temp", file_url)
download_file(file_url, "temp.zip")
self.unzip_file("temp.zip", self.MAIN_PATH) self.unzip_file("temp.zip", self.MAIN_PATH)
except BaseException as e: except BaseException as e:
raise e raise e