Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 14dbdb2d83 | |||
| abda226d63 | |||
| a2dc6f0a49 | |||
| 7a94c26333 | |||
| 9b1ffb384b | |||
| 9566bfe122 | |||
| 89ff103bda | |||
| 6c788db53a | |||
| 344b5fa419 | |||
| c6d161b837 | |||
| 2065ba0c60 | |||
| a481fd1a3e | |||
| c50bcdbdb9 | |||
| 36a2a7632c | |||
| e77b7014e6 | |||
| d57fd0f827 | |||
| 6a83d2a62a | |||
| 2d29726c18 | |||
| b241b0f954 | |||
| 171dd1dc02 |
@@ -12,9 +12,9 @@
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
</a>
|
||||
|
||||
<a href="https://astrbot.soulter.top/center">项目部署</a> |
|
||||
<a href="https://astrbot.soulter.top/docs/main">快速开始</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a> |
|
||||
<a href="https://astrbot.soulter.top/center/docs/%E5%BC%80%E5%8F%91/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91">插件开发</a>
|
||||
<a href="https://astrbot.soulter.top/docs/develop/plugin4p">插件开发</a>
|
||||
</div>
|
||||
|
||||
## 🛠️ 功能
|
||||
|
||||
+19
-7
@@ -10,6 +10,7 @@ from model.plugin.manager import PluginManager
|
||||
from model.platform.manager import PlatformManager
|
||||
from typing import Dict, List, Union
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from util.cmd_config import CmdConfig
|
||||
@@ -23,15 +24,26 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
self.config_helper: CmdConfig = CmdConfig()
|
||||
self.config_helper = CmdConfig()
|
||||
|
||||
# load configs and ensure the backward compatibility
|
||||
init_configs()
|
||||
try_migrate_config()
|
||||
self.configs = inject_to_context(self.context)
|
||||
logger.info("AstrBot v" + self.context.version)
|
||||
self.context.config_helper = self.config_helper
|
||||
self.context.base_config = self.config_helper.cached_config
|
||||
|
||||
self.context.default_personality = {
|
||||
"name": "default",
|
||||
"prompt": self.context.base_config.get("default_personality_str", ""),
|
||||
}
|
||||
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
|
||||
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
|
||||
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
|
||||
self.context.nick = nick_qq
|
||||
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
|
||||
self.context.version = VERSION
|
||||
|
||||
logger.info("AstrBot v" + self.context.version)
|
||||
|
||||
# apply proxy settings
|
||||
http_proxy = self.context.base_config.get("http_proxy")
|
||||
https_proxy = self.context.base_config.get("https_proxy")
|
||||
@@ -93,9 +105,9 @@ class AstrBotBootstrap():
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def load_llm(self):
|
||||
if 'openai' in self.configs and \
|
||||
len(self.configs['openai']['key']) and \
|
||||
self.configs['openai']['key'][0] is not None:
|
||||
if 'openai' in self.config_helper.cached_config and \
|
||||
len(self.config_helper.cached_config['openai']['key']) and \
|
||||
self.config_helper.cached_config['openai']['key'][0] is not None:
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
|
||||
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
|
||||
|
||||
@@ -112,9 +112,11 @@ class MessageHandler():
|
||||
self.rate_limit_helper = RateLimitHelper(context)
|
||||
self.content_safety_helper = ContentSafetyHelper(context)
|
||||
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
|
||||
if self.llm_wake_prefix:
|
||||
self.llm_wake_prefix = self.llm_wake_prefix.strip()
|
||||
self.nicks = self.context.nick
|
||||
self.provider = provider
|
||||
self.reply_prefix = self.context.reply_prefix
|
||||
self.reply_prefix = str(self.context.reply_prefix)
|
||||
|
||||
def set_provider(self, provider: Provider):
|
||||
self.provider = provider
|
||||
@@ -144,7 +146,8 @@ class MessageHandler():
|
||||
if msg_plain.startswith(nick):
|
||||
msg_plain = msg_plain.removeprefix(nick)
|
||||
break
|
||||
|
||||
message.message_str = msg_plain
|
||||
|
||||
# scan candidate commands
|
||||
cmd_res = await self.command_manager.scan_command(message, self.context)
|
||||
if cmd_res:
|
||||
@@ -156,10 +159,12 @@ class MessageHandler():
|
||||
)
|
||||
|
||||
# check if the message is a llm-wake-up command
|
||||
if not msg_plain.startswith(self.llm_wake_prefix):
|
||||
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
|
||||
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
|
||||
return
|
||||
|
||||
if not provider:
|
||||
logger.debug("没有任何 LLM 可用,忽略。")
|
||||
return
|
||||
|
||||
# check the content safety
|
||||
|
||||
+5
-2
@@ -45,6 +45,10 @@ class AstrBotDashBoard():
|
||||
def index():
|
||||
# 返回页面
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.get("/auth/login")
|
||||
def _():
|
||||
return self.dashboard_be.send_static_file("index.html")
|
||||
|
||||
@self.dashboard_be.get("/config")
|
||||
def rt_config():
|
||||
@@ -86,12 +90,11 @@ class AstrBotDashBoard():
|
||||
|
||||
@self.dashboard_be.post("/api/change_password")
|
||||
def change_password():
|
||||
password = self.context.base_config("dashboard_password", "")
|
||||
password = self.context.base_config.get("dashboard_password", "")
|
||||
# 获得请求体
|
||||
post_data = request.json
|
||||
if post_data["password"] == password:
|
||||
self.context.config_helper.put("dashboard_password", post_data["new_password"])
|
||||
self.context.base_config['dashboard_password'] = post_data["new_password"]
|
||||
return Response(
|
||||
status="success",
|
||||
message="修改成功。",
|
||||
|
||||
@@ -4,12 +4,13 @@ import asyncio
|
||||
import sys
|
||||
import warnings
|
||||
import traceback
|
||||
import mimetypes
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
logo_tmpl = """
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
@@ -42,6 +43,11 @@ def check_env():
|
||||
|
||||
os.makedirs("data/config", exist_ok=True)
|
||||
os.makedirs("temp", exist_ok=True)
|
||||
|
||||
# workaround for issue #181
|
||||
mimetypes.add_type("text/javascript", ".js")
|
||||
mimetypes.add_type("text/javascript", ".mjs")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_env()
|
||||
|
||||
@@ -62,14 +62,16 @@ class InternalCommandHandler:
|
||||
return CommandResult().message("你没有权限使用该指令。")
|
||||
l = message_str.split(" ")
|
||||
if len(l) == 1:
|
||||
return CommandResult().message("设置机器人唤醒词,支持多唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称1 昵称2 昵称3")
|
||||
nick = l[1:]
|
||||
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.nick}")
|
||||
nick = l[1].strip()
|
||||
if not nick:
|
||||
return CommandResult().message("wake: 请指定唤醒词。")
|
||||
context.config_helper.put("nick_qq", nick)
|
||||
context.nick = tuple(nick)
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain=f"已经成功将唤醒词设定为 {nick}",
|
||||
message_chain=f"已经成功将唤醒词设定为 {nick}。",
|
||||
)
|
||||
|
||||
def update(self, message: AstrMessageEvent, context: Context):
|
||||
@@ -230,15 +232,17 @@ class InternalCommandHandler:
|
||||
)
|
||||
|
||||
def t2i_toggle(self, message: AstrMessageEvent, context: Context):
|
||||
p = context.config_helper.get("qq_pic_mode", True)
|
||||
p = context.t2i_mode
|
||||
if p:
|
||||
context.config_helper.put("qq_pic_mode", False)
|
||||
context.t2i_mode = False
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain="已关闭文本转图片模式。",
|
||||
)
|
||||
context.config_helper.put("qq_pic_mode", True)
|
||||
context.t2i_mode = True
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
|
||||
@@ -20,7 +20,8 @@ class CommandMetadata():
|
||||
inner_command: bool
|
||||
plugin_metadata: PluginMetadata
|
||||
handler: callable
|
||||
description: str
|
||||
use_regex: bool = False
|
||||
description: str = ""
|
||||
|
||||
class CommandManager():
|
||||
def __init__(self):
|
||||
@@ -33,10 +34,13 @@ class CommandManager():
|
||||
description: str,
|
||||
priority: int,
|
||||
handler: callable,
|
||||
use_regex: bool = False,
|
||||
plugin_metadata: PluginMetadata = None,
|
||||
):
|
||||
'''
|
||||
优先级越高,越先被处理。
|
||||
|
||||
use_regex: 是否使用正则表达式匹配指令。
|
||||
'''
|
||||
if command in self.commands_handler:
|
||||
raise ValueError(f"Command {command} already exists.")
|
||||
@@ -48,6 +52,7 @@ class CommandManager():
|
||||
inner_command=plugin_metadata == None,
|
||||
plugin_metadata=plugin_metadata,
|
||||
handler=handler,
|
||||
use_regex=use_regex,
|
||||
description=description
|
||||
)
|
||||
if plugin_metadata:
|
||||
@@ -64,13 +69,24 @@ class CommandManager():
|
||||
break
|
||||
if not plugin:
|
||||
logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}。")
|
||||
self.register(request.command_name, request.description, request.priority, request.handler, plugin.metadata)
|
||||
else:
|
||||
self.register(command=request.command_name,
|
||||
description=request.description,
|
||||
priority=request.priority,
|
||||
handler=request.handler,
|
||||
use_regex=request.use_regex,
|
||||
plugin_metadata=plugin.metadata)
|
||||
self.plugin_commands_waitlist = []
|
||||
|
||||
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
|
||||
message_str = message_event.message_str
|
||||
for _, command in self.commands:
|
||||
if message_str.startswith(command):
|
||||
trig = False
|
||||
if self.commands_handler[command].use_regex:
|
||||
trig = self.command_parser.regex_match(message_str, command)
|
||||
else:
|
||||
trig = message_str.startswith(command)
|
||||
if trig:
|
||||
logger.info(f"触发 {command} 指令。")
|
||||
command_result = await self.execute_handler(command, message_event, context)
|
||||
if command_result.hit:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import re
|
||||
|
||||
class CommandTokens():
|
||||
def __init__(self) -> None:
|
||||
self.tokens = []
|
||||
@@ -16,4 +18,8 @@ class CommandParser():
|
||||
cmd_tokens = CommandTokens()
|
||||
cmd_tokens.tokens = message.split(" ")
|
||||
cmd_tokens.len = len(cmd_tokens.tokens)
|
||||
return cmd_tokens
|
||||
return cmd_tokens
|
||||
|
||||
def regex_match(self, message: str, command: str) -> bool:
|
||||
return re.search(command, message, re.MULTILINE) is not None
|
||||
|
||||
@@ -48,6 +48,14 @@ class AIOCQHTTP(Platform):
|
||||
abm.message = []
|
||||
|
||||
message_str = ""
|
||||
if not isinstance(event.message, list):
|
||||
err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。"
|
||||
logger.critical(err)
|
||||
try:
|
||||
self.bot.send(event, err)
|
||||
except BaseException as e:
|
||||
logger.error(f"回复消息失败: {e}")
|
||||
return
|
||||
for m in event.message:
|
||||
t = m['type']
|
||||
a = None
|
||||
@@ -75,14 +83,12 @@ class AIOCQHTTP(Platform):
|
||||
abm = self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
# return {'reply': event.message}
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
# return {'reply': event.message}
|
||||
|
||||
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
|
||||
|
||||
@@ -117,8 +123,8 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.sender.user_id)
|
||||
if sender_id == self.context.config_helper.get('admin_qq', '') or \
|
||||
sender_id in self.context.config_helper.get('other_admins', []):
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -154,7 +160,7 @@ class AIOCQHTTP(Platform):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
|
||||
@@ -112,8 +112,8 @@ class QQGOCQ(Platform):
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.raw_message.user_id)
|
||||
if sender_id == self.context.config_helper.get('admin_qq', '') or \
|
||||
sender_id in self.context.config_helper.get('other_admins', []):
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -152,7 +152,7 @@ class QQGOCQ(Platform):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
@@ -186,7 +186,7 @@ class QQGOCQ(Platform):
|
||||
plain_text_len += len(i.text)
|
||||
elif isinstance(i, Image):
|
||||
image_num += 1
|
||||
if plain_text_len > self.context.config_helper.get('qq_forward_threshold', 200):
|
||||
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
|
||||
# 删除At
|
||||
for i in message_chain:
|
||||
if isinstance(i, At):
|
||||
|
||||
@@ -43,6 +43,11 @@ class botClient(Client):
|
||||
# 转换层
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
await self.platform.handle_msg(abm)
|
||||
|
||||
# 收到 C2C 消息
|
||||
async def on_c2c_message_create(self, message: botpy.message.C2CMessage):
|
||||
abm = self.platform._parse_from_qqofficial(message, MessageType.FRIEND_MESSAGE)
|
||||
await self.platform.handle_msg(abm)
|
||||
|
||||
|
||||
class QQOfficial(Platform):
|
||||
@@ -110,11 +115,17 @@ class QQOfficial(Platform):
|
||||
abm.tag = "qqchan"
|
||||
msg: List[BaseMessageComponent] = []
|
||||
|
||||
if message_type == MessageType.GROUP_MESSAGE:
|
||||
abm.sender = MessageMember(
|
||||
message.author.member_openid,
|
||||
""
|
||||
)
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
|
||||
if isinstance(message, botpy.message.GroupMessage):
|
||||
abm.sender = MessageMember(
|
||||
message.author.member_openid,
|
||||
""
|
||||
)
|
||||
else:
|
||||
abm.sender = MessageMember(
|
||||
message.author.user_openid,
|
||||
""
|
||||
)
|
||||
abm.message_str = message.content.strip()
|
||||
abm.self_id = "unknown_selfid"
|
||||
|
||||
@@ -129,8 +140,7 @@ class QQOfficial(Platform):
|
||||
msg.append(img)
|
||||
abm.message = msg
|
||||
|
||||
elif message_type == MessageType.GUILD_MESSAGE or message_type == MessageType.FRIEND_MESSAGE:
|
||||
# 目前对于 FRIEND_MESSAGE 只处理频道私聊
|
||||
elif isinstance(message, botpy.message.Message) or isinstance(message, botpy.message.DirectMessage):
|
||||
try:
|
||||
abm.self_id = str(message.mentions[0].id)
|
||||
except:
|
||||
@@ -178,7 +188,7 @@ class QQOfficial(Platform):
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
assert isinstance(message.raw_message, (botpy.message.Message,
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage))
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
is_group = message.type != MessageType.FRIEND_MESSAGE
|
||||
|
||||
_t = "/私聊" if not is_group else ""
|
||||
@@ -199,8 +209,8 @@ class QQOfficial(Platform):
|
||||
|
||||
# 解析出 role
|
||||
sender_id = message.sender.user_id
|
||||
if sender_id == self.context.config_helper.get('admin_qqchan', None) or \
|
||||
sender_id in self.context.config_helper.get('other_admins', None):
|
||||
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
|
||||
sender_id in self.context.base_config.get('other_admins', None):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -230,7 +240,7 @@ class QQOfficial(Platform):
|
||||
'''
|
||||
source = message.raw_message
|
||||
assert isinstance(source, (botpy.message.Message,
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage))
|
||||
botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
|
||||
logger.info(
|
||||
f"{message.sender.nickname}({message.sender.user_id}) <- {self.parse_message_outline(result_message)}")
|
||||
|
||||
@@ -239,7 +249,7 @@ class QQOfficial(Platform):
|
||||
msg_ref = None
|
||||
rendered_images = []
|
||||
|
||||
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(result_message, list):
|
||||
if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(result_message)
|
||||
|
||||
if isinstance(result_message, list):
|
||||
@@ -258,12 +268,14 @@ class QQOfficial(Platform):
|
||||
'message_reference': msg_ref
|
||||
}
|
||||
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if isinstance(message.raw_message, botpy.message.GroupMessage):
|
||||
data['group_openid'] = str(source.group_openid)
|
||||
elif message.type == MessageType.GUILD_MESSAGE:
|
||||
elif isinstance(message.raw_message, botpy.message.Message):
|
||||
data['channel_id'] = source.channel_id
|
||||
elif message.type == MessageType.FRIEND_MESSAGE:
|
||||
elif isinstance(message.raw_message, botpy.message.DirectMessage):
|
||||
data['guild_id'] = source.guild_id
|
||||
elif isinstance(message.raw_message, botpy.message.C2CMessage):
|
||||
data['openid'] = source.author.user_openid
|
||||
if image_path:
|
||||
data['file_image'] = image_path
|
||||
if rendered_images:
|
||||
@@ -308,7 +320,7 @@ class QQOfficial(Platform):
|
||||
return await self._reply(**data)
|
||||
|
||||
async def _reply(self, **kwargs):
|
||||
if 'group_openid' in kwargs:
|
||||
if 'group_openid' in kwargs or 'openid' in kwargs:
|
||||
# QQ群组消息
|
||||
if 'file_image' in kwargs and kwargs['file_image']:
|
||||
file_image_path = kwargs['file_image'].replace("file:///", "")
|
||||
@@ -320,14 +332,20 @@ class QQOfficial(Platform):
|
||||
logger.debug(f"上传图片: {file_image_path}")
|
||||
image_url = await self.context.image_uploader.upload_image(file_image_path)
|
||||
logger.debug(f"上传成功: {image_url}")
|
||||
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
|
||||
if 'group_openid' in kwargs:
|
||||
media = await self.client.api.post_group_file(kwargs['group_openid'], 1, image_url)
|
||||
elif 'openid' in kwargs:
|
||||
media = await self.client.api.post_c2c_file(kwargs['openid'], 1, image_url)
|
||||
del kwargs['file_image']
|
||||
kwargs['media'] = media
|
||||
logger.debug(f"发送群图片: {media}")
|
||||
kwargs['msg_type'] = 7 # 富媒体
|
||||
if self.test_mode:
|
||||
return kwargs
|
||||
await self.client.api.post_group_message(**kwargs)
|
||||
if 'group_openid' in kwargs:
|
||||
await self.client.api.post_group_message(**kwargs)
|
||||
elif 'openid' in kwargs:
|
||||
await self.client.api.post_c2c_message(**kwargs)
|
||||
elif 'channel_id' in kwargs:
|
||||
# 频道消息
|
||||
if 'file_image' in kwargs and kwargs['file_image']:
|
||||
@@ -338,7 +356,7 @@ class QQOfficial(Platform):
|
||||
if self.test_mode:
|
||||
return kwargs
|
||||
await self.client.api.post_message(**kwargs)
|
||||
else:
|
||||
elif 'guild_id' in kwargs:
|
||||
# 频道私聊消息
|
||||
if 'file_image' in kwargs and kwargs['file_image']:
|
||||
kwargs['file_image'] = kwargs['file_image'].replace("file:///", "")
|
||||
@@ -347,16 +365,19 @@ class QQOfficial(Platform):
|
||||
if self.test_mode:
|
||||
return kwargs
|
||||
await self.client.api.post_dms(**kwargs)
|
||||
else:
|
||||
raise ValueError("Unknown target type.")
|
||||
|
||||
async def send_msg(self, target: Dict[str, str], result_message: CommandResult):
|
||||
'''
|
||||
以主动的方式给用户、群或者频道发送一条消息。
|
||||
以主动的方式给频道用户、群、频道或者消息列表用户(QQ用户)发送一条消息。
|
||||
|
||||
`target` 接收一个 dict 类型的值引用。
|
||||
|
||||
- 如果目标是 QQ 群,请添加 key `group_openid`。
|
||||
- 如果目标是 频道消息,请添加 key `channel_id`。
|
||||
- 如果目标是 频道私聊,请添加 key `guild_id`。
|
||||
- 如果目标是 QQ 用户,请添加 key `openid`。
|
||||
'''
|
||||
plain_text, image_path = await self._parse_to_qqofficial(result_message.message_chain)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ class CommandRegisterRequest():
|
||||
description: str
|
||||
priority: int
|
||||
handler: Callable
|
||||
use_regex: bool = False
|
||||
plugin_name: str = None
|
||||
|
||||
class PluginCommandBridge():
|
||||
@@ -20,6 +21,6 @@ class PluginCommandBridge():
|
||||
self.plugin_commands_waitlist: List[CommandRegisterRequest] = []
|
||||
self.cached_plugins = cached_plugins
|
||||
|
||||
def register_command(self, plugin_name, command_name, description, priority, handler):
|
||||
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, plugin_name))
|
||||
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False):
|
||||
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name))
|
||||
|
||||
@@ -123,7 +123,7 @@ class PluginManager():
|
||||
return p
|
||||
|
||||
def uninstall_plugin(self, plugin_name: str):
|
||||
plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins)
|
||||
plugin = self.get_registered_plugin(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
root_dir_name = plugin.root_dir_name
|
||||
@@ -133,7 +133,7 @@ class PluginManager():
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
def update_plugin(self, plugin_name: str):
|
||||
plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins)
|
||||
plugin = self.get_registered_plugin(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
|
||||
@@ -155,6 +155,8 @@ class PluginManager():
|
||||
p = plugin['module']
|
||||
module_path = plugin['module_path']
|
||||
root_dir_name = plugin['pname']
|
||||
|
||||
logger.info(f"正在加载插件 {root_dir_name} ...")
|
||||
|
||||
# self.check_plugin_dept_update(cached_plugins, root_dir_name)
|
||||
|
||||
@@ -166,8 +168,10 @@ class PluginManager():
|
||||
try:
|
||||
# 尝试传入 ctx
|
||||
obj = getattr(module, cls[0])(context=self.context)
|
||||
except:
|
||||
except TypeError:
|
||||
obj = getattr(module, cls[0])()
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
metadata = None
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
os.makedirs("data/openai", exist_ok=True)
|
||||
|
||||
self.cc = CmdConfig
|
||||
self.context = context
|
||||
self.key_data_path = "data/openai/keys.json"
|
||||
self.api_keys = []
|
||||
self.chosen_api_key = None
|
||||
@@ -78,7 +78,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
self.model_configs: Dict = cfg['chatGPTConfigs']
|
||||
super().set_curr_model(self.model_configs['model'])
|
||||
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
|
||||
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
|
||||
self.session_memory: Dict[str, List] = {} # 会话记忆
|
||||
self.session_memory_lock = threading.Lock()
|
||||
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
|
||||
@@ -385,7 +385,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
assert isinstance(completion, ChatCompletion)
|
||||
logger.debug(f"openai completion: {completion.usage}")
|
||||
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("OpenAI API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
usage_tokens = completion.usage.total_tokens
|
||||
@@ -492,7 +494,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
def set_model(self, model: str):
|
||||
self.model_configs['model'] = model
|
||||
self.cc.put_by_dot_str("openai.chatGPTConfigs.model", model)
|
||||
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
|
||||
super().set_curr_model(model)
|
||||
|
||||
def get_configs(self):
|
||||
|
||||
+75
-1
@@ -1 +1,75 @@
|
||||
VERSION = '3.3.3'
|
||||
VERSION = '3.3.7'
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"qqbot": {
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"token": "",
|
||||
},
|
||||
"gocqbot": {
|
||||
"enable": False,
|
||||
},
|
||||
"uniqueSessionMode": False,
|
||||
"dump_history_interval": 10,
|
||||
"limit": {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
},
|
||||
"notice": "",
|
||||
"direct_message_mode": True,
|
||||
"reply_prefix": "",
|
||||
"baidu_aip": {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
},
|
||||
"openai": {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"chatGPTConfigs": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
},
|
||||
"total_tokens_limit": 10000,
|
||||
},
|
||||
"qq_forward_threshold": 200,
|
||||
"qq_welcome": "",
|
||||
"qq_pic_mode": True,
|
||||
"gocq_host": "127.0.0.1",
|
||||
"gocq_http_port": 5700,
|
||||
"gocq_websocket_port": 6700,
|
||||
"gocq_react_group": True,
|
||||
"gocq_react_guild": True,
|
||||
"gocq_react_friend": True,
|
||||
"gocq_react_group_increase": True,
|
||||
"other_admins": [],
|
||||
"CHATGPT_BASE_URL": "",
|
||||
"qqbot_secret": "",
|
||||
"qqofficial_enable_group_message": False,
|
||||
"admin_qq": "",
|
||||
"nick_qq": ["/", "!"],
|
||||
"admin_qqchan": "",
|
||||
"llm_env_prompt": "",
|
||||
"llm_wake_prefix": "",
|
||||
"default_personality_str": "",
|
||||
"openai_image_generate": {
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
},
|
||||
"http_proxy": "",
|
||||
"https_proxy": "",
|
||||
"dashboard_username": "",
|
||||
"dashboard_password": "",
|
||||
"aiocqhttp": {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 0,
|
||||
}
|
||||
}
|
||||
+15
-12
@@ -28,37 +28,40 @@ class Context:
|
||||
|
||||
self.unique_session = False # 独立会话
|
||||
self.version: str = None # 机器人版本
|
||||
self.nick = None # gocq 的唤醒词
|
||||
self.stat = {}
|
||||
self.nick: tuple = None # gocq 的唤醒词
|
||||
self.t2i_mode = False
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
self.reply_prefix = ""
|
||||
|
||||
self.metrics_uploader = None
|
||||
self.updator: AstrBotUpdator = None
|
||||
self.plugin_updator: PluginUpdator = None
|
||||
self.metrics_uploader = None
|
||||
|
||||
self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins)
|
||||
self.image_renderer = TextToImageRenderer()
|
||||
self.image_uploader = ImageUploader()
|
||||
self.message_handler = None # see astrbot/message/handler.py
|
||||
self.ext_tasks: List[Task] = []
|
||||
|
||||
# useless
|
||||
self.reply_prefix = ""
|
||||
|
||||
def register_commands(self,
|
||||
plugin_name: str,
|
||||
command_name: str,
|
||||
description: str,
|
||||
priority: int,
|
||||
handler: callable):
|
||||
handler: callable,
|
||||
use_regex: bool = False):
|
||||
'''
|
||||
注册插件指令。
|
||||
|
||||
`plugin_name`: 插件名,注意需要和你的 metadata 中的一致。
|
||||
`command_name`: 指令名,如 "help"。不需要带前缀。
|
||||
`description`: 指令描述。
|
||||
`priority`: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
|
||||
`handler`: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
|
||||
@param plugin_name: 插件名,注意需要和你的 metadata 中的一致。
|
||||
@param command_name: 指令名,如 "help"。不需要带前缀。
|
||||
@param description: 指令描述。
|
||||
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
|
||||
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
|
||||
@param use_regex: 是否使用正则表达式匹配指令名。
|
||||
'''
|
||||
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler)
|
||||
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex)
|
||||
|
||||
def register_task(self, coro: Awaitable, task_name: str):
|
||||
'''
|
||||
|
||||
+38
-29
@@ -1,19 +1,31 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Union
|
||||
from type.config import DEFAULT_CONFIG
|
||||
|
||||
cpath = "data/cmd_config.json"
|
||||
|
||||
def check_exist():
|
||||
if not os.path.exists(cpath):
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump({}, f, indent=4, ensure_ascii=False)
|
||||
json.dump({}, f, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
class CmdConfig():
|
||||
def __init__(self) -> None:
|
||||
self.cached_config: dict = {}
|
||||
self.init_configs()
|
||||
|
||||
def init_configs(self):
|
||||
'''
|
||||
初始化必需的配置项
|
||||
'''
|
||||
self.init_config_items(DEFAULT_CONFIG)
|
||||
|
||||
@staticmethod
|
||||
def get(key, default=None):
|
||||
'''
|
||||
从文件系统中直接获取配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
@@ -22,28 +34,33 @@ class CmdConfig():
|
||||
else:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_all():
|
||||
def get_all(self):
|
||||
'''
|
||||
从文件系统中获取所有配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
return json.load(f)
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
conf = json.loads(conf_str)
|
||||
return conf
|
||||
|
||||
@staticmethod
|
||||
def put(key, value):
|
||||
check_exist()
|
||||
def put(self, key, value):
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
d[key] = value
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=4, ensure_ascii=False)
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
self.cached_config[key] = value
|
||||
|
||||
@staticmethod
|
||||
def put_by_dot_str(key: str, value):
|
||||
'''
|
||||
根据点分割的字符串,将值写入配置文件
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
_d = d
|
||||
@@ -54,30 +71,22 @@ class CmdConfig():
|
||||
else:
|
||||
_d = _d[_ks[i]]
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=4, ensure_ascii=False)
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
@staticmethod
|
||||
def init_attributes(key: Union[str, list], init_val=""):
|
||||
check_exist()
|
||||
conf_str = ''
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'):
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
d = json.loads(conf_str)
|
||||
def init_config_items(self, d: dict):
|
||||
conf = self.get_all()
|
||||
|
||||
if not self.cached_config:
|
||||
self.cached_config = conf
|
||||
|
||||
_tag = False
|
||||
|
||||
if isinstance(key, str):
|
||||
if key not in d:
|
||||
d[key] = init_val
|
||||
for key, val in d.items():
|
||||
if key not in conf:
|
||||
conf[key] = val
|
||||
_tag = True
|
||||
elif isinstance(key, list):
|
||||
for k in key:
|
||||
if k not in d:
|
||||
d[k] = init_val
|
||||
_tag = True
|
||||
if _tag:
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=4, ensure_ascii=False)
|
||||
json.dump(conf, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
+1
-132
@@ -1,89 +1,5 @@
|
||||
import json, os
|
||||
from util.cmd_config import CmdConfig
|
||||
from type.config import VERSION
|
||||
from type.types import Context
|
||||
|
||||
def init_configs():
|
||||
'''
|
||||
初始化必需的配置项
|
||||
'''
|
||||
cc = CmdConfig()
|
||||
|
||||
cc.init_attributes("qqbot", {
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"token": "",
|
||||
})
|
||||
cc.init_attributes("gocqbot", {
|
||||
"enable": False,
|
||||
})
|
||||
cc.init_attributes("uniqueSessionMode", False)
|
||||
cc.init_attributes("dump_history_interval", 10)
|
||||
cc.init_attributes("limit", {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
})
|
||||
cc.init_attributes("notice", "")
|
||||
cc.init_attributes("direct_message_mode", True)
|
||||
cc.init_attributes("reply_prefix", "")
|
||||
cc.init_attributes("baidu_aip", {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
})
|
||||
cc.init_attributes("openai", {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"chatGPTConfigs": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
},
|
||||
"total_tokens_limit": 10000,
|
||||
})
|
||||
|
||||
|
||||
cc.init_attributes("qq_forward_threshold", 200)
|
||||
cc.init_attributes("qq_welcome", "")
|
||||
cc.init_attributes("qq_pic_mode", True)
|
||||
cc.init_attributes("gocq_host", "127.0.0.1")
|
||||
cc.init_attributes("gocq_http_port", 5700)
|
||||
cc.init_attributes("gocq_websocket_port", 6700)
|
||||
cc.init_attributes("gocq_react_group", True)
|
||||
cc.init_attributes("gocq_react_guild", True)
|
||||
cc.init_attributes("gocq_react_friend", True)
|
||||
cc.init_attributes("gocq_react_group_increase", True)
|
||||
cc.init_attributes("other_admins", [])
|
||||
cc.init_attributes("CHATGPT_BASE_URL", "")
|
||||
cc.init_attributes("qqbot_secret", "")
|
||||
cc.init_attributes("qqofficial_enable_group_message", False)
|
||||
cc.init_attributes("admin_qq", "")
|
||||
cc.init_attributes("nick_qq", ["!", "!", "ai"])
|
||||
cc.init_attributes("admin_qqchan", "")
|
||||
cc.init_attributes("llm_env_prompt", "")
|
||||
cc.init_attributes("llm_wake_prefix", "")
|
||||
cc.init_attributes("default_personality_str", "")
|
||||
cc.init_attributes("openai_image_generate", {
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
})
|
||||
cc.init_attributes("http_proxy", "")
|
||||
cc.init_attributes("https_proxy", "")
|
||||
cc.init_attributes("dashboard_username", "")
|
||||
cc.init_attributes("dashboard_password", "")
|
||||
|
||||
# aiocqhttp 适配器
|
||||
cc.init_attributes("aiocqhttp", {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 0,
|
||||
})
|
||||
|
||||
def try_migrate_config():
|
||||
'''
|
||||
@@ -97,51 +13,4 @@ def try_migrate_config():
|
||||
try:
|
||||
os.remove("cmd_config.json")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def inject_to_context(context: Context):
|
||||
'''
|
||||
将配置注入到 Context 中。
|
||||
this method returns all the configs
|
||||
'''
|
||||
cc = CmdConfig()
|
||||
|
||||
context.version = VERSION
|
||||
context.base_config = cc.get_all()
|
||||
|
||||
cfg = context.base_config
|
||||
|
||||
if 'reply_prefix' in cfg:
|
||||
# 适配旧版配置
|
||||
if isinstance(cfg['reply_prefix'], dict):
|
||||
context.reply_prefix = ""
|
||||
cfg['reply_prefix'] = ""
|
||||
cc.put("reply_prefix", "")
|
||||
else:
|
||||
context.reply_prefix = cfg['reply_prefix']
|
||||
|
||||
default_personality_str = cc.get("default_personality_str", "")
|
||||
if default_personality_str == "":
|
||||
context.default_personality = None
|
||||
else:
|
||||
context.default_personality = {
|
||||
"name": "default",
|
||||
"prompt": default_personality_str,
|
||||
}
|
||||
|
||||
if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']:
|
||||
context.unique_session = True
|
||||
else:
|
||||
context.unique_session = False
|
||||
|
||||
nick_qq = cc.get("nick_qq", None)
|
||||
if nick_qq == None:
|
||||
nick_qq = ("/", )
|
||||
if isinstance(nick_qq, str):
|
||||
nick_qq = (nick_qq, )
|
||||
if isinstance(nick_qq, list):
|
||||
nick_qq = tuple(nick_qq)
|
||||
context.nick = nick_qq
|
||||
context.t2i_mode = cc.get("qq_pic_mode", True)
|
||||
|
||||
return cfg
|
||||
pass
|
||||
+8
-1
@@ -36,7 +36,14 @@ class MetricUploader():
|
||||
|
||||
for plugin in context.cached_plugins:
|
||||
self.plugin_stats[plugin.metadata.plugin_name] = {
|
||||
"metadata": plugin.metadata
|
||||
"metadata": {
|
||||
"plugin_name": plugin.metadata.plugin_name,
|
||||
"plugin_type": plugin.metadata.plugin_type.value,
|
||||
"author": plugin.metadata.author,
|
||||
"desc": plugin.metadata.desc,
|
||||
"version": plugin.metadata.version,
|
||||
"repo": plugin.metadata.repo,
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
|
||||
@@ -34,7 +34,12 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
if delay: time.sleep(delay)
|
||||
py = sys.executable
|
||||
self.terminate_child_processes()
|
||||
os.execl(py, py, *sys.argv)
|
||||
py = py.replace(" ", "\\ ")
|
||||
try:
|
||||
os.execl(py, py, *sys.argv)
|
||||
except Exception as e:
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
|
||||
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
return super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||
|
||||
Reference in New Issue
Block a user