Compare commits

..

17 Commits

Author SHA1 Message Date
Soulter 14dbdb2d83 feat: 插件支持正则匹配 2024-08-05 12:12:00 -04:00
Soulter abda226d63 Merge pull request #183 from irorange27/master
fix: fix logo syntax warning
2024-08-05 23:37:57 +08:00
niina a2dc6f0a49 fix: fix logo syntax warning 2024-08-05 22:53:45 +08:00
Soulter 7a94c26333 fix: 修复 wake 唤醒词无法触发 command 的问题 2024-08-05 05:02:57 -04:00
Soulter 9b1ffb384b perf: 优化aiocqhttp适配器的异常处理 2024-08-05 04:46:12 -04:00
Soulter 9566bfe122 workaround for issue #181 2024-08-03 17:03:38 +08:00
Soulter 89ff103bda chore: Add mimetypes workaround for issue #188 2024-08-03 17:02:45 +08:00
Soulter 6c788db53a Merge remote-tracking branch 'refs/remotes/origin/master' 2024-08-03 16:17:25 +08:00
Soulter 344b5fa419 fix: f-string eror 2024-08-03 16:17:04 +08:00
Soulter c6d161b837 Update README.md 2024-08-03 15:04:20 +08:00
Soulter 2065ba0c60 Update README.md 2024-08-03 01:05:27 +08:00
Soulter a481fd1a3e fix: Strip leading and trailing whitespace from llm_wake_prefix 2024-08-02 23:17:35 +08:00
Soulter c50bcdbdb9 fix: Register command only if plugin is found 2024-08-02 22:48:04 +08:00
Soulter 36a2a7632c fix: 优化初始化、消息处理时的配置读取过程,减少性能损耗 2024-07-31 23:38:31 +08:00
Soulter e77b7014e6 fix: 修复更新、卸载插件时的报错 2024-07-30 09:15:45 +08:00
Soulter d57fd0f827 fix: metadata is not seralizable 2024-07-29 09:47:42 +08:00
Soulter 6a83d2a62a update version 2024-07-28 12:11:07 +08:00
20 changed files with 251 additions and 220 deletions
+2 -2
View File
@@ -12,9 +12,9 @@
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
</a>
<a href="https://astrbot.soulter.top/center">项目部署</a>
<a href="https://astrbot.soulter.top/docs/main">快速开始</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
<a href="https://astrbot.soulter.top/center/docs/%E5%BC%80%E5%8F%91/%E6%8F%92%E4%BB%B6%E5%BC%80%E5%8F%91">插件开发</a>
<a href="https://astrbot.soulter.top/docs/develop/plugin4p">插件开发</a>
</div>
## 🛠️ 功能
+19 -7
View File
@@ -10,6 +10,7 @@ from model.plugin.manager import PluginManager
from model.platform.manager import PlatformManager
from typing import Dict, List, Union
from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from util.cmd_config import CmdConfig
@@ -23,15 +24,26 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotBootstrap():
def __init__(self) -> None:
self.context = Context()
self.config_helper: CmdConfig = CmdConfig()
self.config_helper = CmdConfig()
# load configs and ensure the backward compatibility
init_configs()
try_migrate_config()
self.configs = inject_to_context(self.context)
logger.info("AstrBot v" + self.context.version)
self.context.config_helper = self.config_helper
self.context.base_config = self.config_helper.cached_config
self.context.default_personality = {
"name": "default",
"prompt": self.context.base_config.get("default_personality_str", ""),
}
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
self.context.nick = nick_qq
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
self.context.version = VERSION
logger.info("AstrBot v" + self.context.version)
# apply proxy settings
http_proxy = self.context.base_config.get("http_proxy")
https_proxy = self.context.base_config.get("https_proxy")
@@ -93,9 +105,9 @@ class AstrBotBootstrap():
await asyncio.sleep(5)
def load_llm(self):
if 'openai' in self.configs and \
len(self.configs['openai']['key']) and \
self.configs['openai']['key'][0] is not None:
if 'openai' in self.config_helper.cached_config and \
len(self.config_helper.cached_config['openai']['key']) and \
self.config_helper.cached_config['openai']['key'][0] is not None:
from model.provider.openai_official import ProviderOpenAIOfficial
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
+8 -3
View File
@@ -112,9 +112,11 @@ class MessageHandler():
self.rate_limit_helper = RateLimitHelper(context)
self.content_safety_helper = ContentSafetyHelper(context)
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
if self.llm_wake_prefix:
self.llm_wake_prefix = self.llm_wake_prefix.strip()
self.nicks = self.context.nick
self.provider = provider
self.reply_prefix = self.context.reply_prefix
self.reply_prefix = str(self.context.reply_prefix)
def set_provider(self, provider: Provider):
self.provider = provider
@@ -144,7 +146,8 @@ class MessageHandler():
if msg_plain.startswith(nick):
msg_plain = msg_plain.removeprefix(nick)
break
message.message_str = msg_plain
# scan candidate commands
cmd_res = await self.command_manager.scan_command(message, self.context)
if cmd_res:
@@ -156,10 +159,12 @@ class MessageHandler():
)
# check if the message is a llm-wake-up command
if not msg_plain.startswith(self.llm_wake_prefix):
if self.llm_wake_prefix and not msg_plain.startswith(self.llm_wake_prefix):
logger.debug(f"消息 `{msg_plain}` 没有以 LLM 唤醒前缀 `{self.llm_wake_prefix}` 开头,忽略。")
return
if not provider:
logger.debug("没有任何 LLM 可用,忽略。")
return
# check the content safety
+5 -2
View File
@@ -45,6 +45,10 @@ class AstrBotDashBoard():
def index():
# 返回页面
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/auth/login")
def _():
return self.dashboard_be.send_static_file("index.html")
@self.dashboard_be.get("/config")
def rt_config():
@@ -86,12 +90,11 @@ class AstrBotDashBoard():
@self.dashboard_be.post("/api/change_password")
def change_password():
password = self.context.base_config("dashboard_password", "")
password = self.context.base_config.get("dashboard_password", "")
# 获得请求体
post_data = request.json
if post_data["password"] == password:
self.context.config_helper.put("dashboard_password", post_data["new_password"])
self.context.base_config['dashboard_password'] = post_data["new_password"]
return Response(
status="success",
message="修改成功。",
+7 -1
View File
@@ -4,12 +4,13 @@ import asyncio
import sys
import warnings
import traceback
import mimetypes
from astrbot.bootstrap import AstrBotBootstrap
from SparkleLogging.utils.core import LogManager
from logging import Formatter
warnings.filterwarnings("ignore")
logo_tmpl = """
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
@@ -42,6 +43,11 @@ def check_env():
os.makedirs("data/config", exist_ok=True)
os.makedirs("temp", exist_ok=True)
# workaround for issue #181
mimetypes.add_type("text/javascript", ".js")
mimetypes.add_type("text/javascript", ".mjs")
mimetypes.add_type("application/json", ".json")
if __name__ == "__main__":
check_env()
+8 -4
View File
@@ -62,14 +62,16 @@ class InternalCommandHandler:
return CommandResult().message("你没有权限使用该指令。")
l = message_str.split(" ")
if len(l) == 1:
return CommandResult().message("设置机器人唤醒词,支持多唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称1 昵称2 昵称3")
nick = l[1:]
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.nick}")
nick = l[1].strip()
if not nick:
return CommandResult().message("wake: 请指定唤醒词。")
context.config_helper.put("nick_qq", nick)
context.nick = tuple(nick)
return CommandResult(
hit=True,
success=True,
message_chain=f"已经成功将唤醒词设定为 {nick}",
message_chain=f"已经成功将唤醒词设定为 {nick}",
)
def update(self, message: AstrMessageEvent, context: Context):
@@ -230,15 +232,17 @@ class InternalCommandHandler:
)
def t2i_toggle(self, message: AstrMessageEvent, context: Context):
p = context.config_helper.get("qq_pic_mode", True)
p = context.t2i_mode
if p:
context.config_helper.put("qq_pic_mode", False)
context.t2i_mode = False
return CommandResult(
hit=True,
success=True,
message_chain="已关闭文本转图片模式。",
)
context.config_helper.put("qq_pic_mode", True)
context.t2i_mode = True
return CommandResult(
hit=True,
+19 -3
View File
@@ -20,7 +20,8 @@ class CommandMetadata():
inner_command: bool
plugin_metadata: PluginMetadata
handler: callable
description: str
use_regex: bool = False
description: str = ""
class CommandManager():
def __init__(self):
@@ -33,10 +34,13 @@ class CommandManager():
description: str,
priority: int,
handler: callable,
use_regex: bool = False,
plugin_metadata: PluginMetadata = None,
):
'''
优先级越高,越先被处理。
use_regex: 是否使用正则表达式匹配指令。
'''
if command in self.commands_handler:
raise ValueError(f"Command {command} already exists.")
@@ -48,6 +52,7 @@ class CommandManager():
inner_command=plugin_metadata == None,
plugin_metadata=plugin_metadata,
handler=handler,
use_regex=use_regex,
description=description
)
if plugin_metadata:
@@ -64,13 +69,24 @@ class CommandManager():
break
if not plugin:
logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}")
self.register(request.command_name, request.description, request.priority, request.handler, plugin.metadata)
else:
self.register(command=request.command_name,
description=request.description,
priority=request.priority,
handler=request.handler,
use_regex=request.use_regex,
plugin_metadata=plugin.metadata)
self.plugin_commands_waitlist = []
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
message_str = message_event.message_str
for _, command in self.commands:
if message_str.startswith(command):
trig = False
if self.commands_handler[command].use_regex:
trig = self.command_parser.regex_match(message_str, command)
else:
trig = message_str.startswith(command)
if trig:
logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context)
if command_result.hit:
+7 -1
View File
@@ -1,3 +1,5 @@
import re
class CommandTokens():
def __init__(self) -> None:
self.tokens = []
@@ -16,4 +18,8 @@ class CommandParser():
cmd_tokens = CommandTokens()
cmd_tokens.tokens = message.split(" ")
cmd_tokens.len = len(cmd_tokens.tokens)
return cmd_tokens
return cmd_tokens
def regex_match(self, message: str, command: str) -> bool:
return re.search(command, message, re.MULTILINE) is not None
+11 -5
View File
@@ -48,6 +48,14 @@ class AIOCQHTTP(Platform):
abm.message = []
message_str = ""
if not isinstance(event.message, list):
err = f"aiocqhttp: 无法识别的消息类型: {str(event.message)},此条消息将被忽略。如果您在使用 go-cqhttp,请将其配置文件中的 message.post-format 更改为 array。"
logger.critical(err)
try:
self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return
for m in event.message:
t = m['type']
a = None
@@ -75,14 +83,12 @@ class AIOCQHTTP(Platform):
abm = self.convert_message(event)
if abm:
await self.handle_msg(abm)
# return {'reply': event.message}
@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
if abm:
await self.handle_msg(abm)
# return {'reply': event.message}
bot = self.bot.run_task(host=self.host, port=int(self.port), shutdown_trigger=self.shutdown_trigger_placeholder)
@@ -117,8 +123,8 @@ class AIOCQHTTP(Platform):
# 解析 role
sender_id = str(message.sender.user_id)
if sender_id == self.context.config_helper.get('admin_qq', '') or \
sender_id in self.context.config_helper.get('other_admins', []):
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
role = 'admin'
else:
role = 'member'
@@ -154,7 +160,7 @@ class AIOCQHTTP(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
+4 -4
View File
@@ -112,8 +112,8 @@ class QQGOCQ(Platform):
# 解析 role
sender_id = str(message.raw_message.user_id)
if sender_id == self.context.config_helper.get('admin_qq', '') or \
sender_id in self.context.config_helper.get('other_admins', []):
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
role = 'admin'
else:
role = 'member'
@@ -152,7 +152,7 @@ class QQGOCQ(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
@@ -186,7 +186,7 @@ class QQGOCQ(Platform):
plain_text_len += len(i.text)
elif isinstance(i, Image):
image_num += 1
if plain_text_len > self.context.config_helper.get('qq_forward_threshold', 200):
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
# 删除At
for i in message_chain:
if isinstance(i, At):
+3 -3
View File
@@ -209,8 +209,8 @@ class QQOfficial(Platform):
# 解析出 role
sender_id = message.sender.user_id
if sender_id == self.context.config_helper.get('admin_qqchan', None) or \
sender_id in self.context.config_helper.get('other_admins', None):
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
sender_id in self.context.base_config.get('other_admins', None):
role = 'admin'
else:
role = 'member'
@@ -249,7 +249,7 @@ class QQOfficial(Platform):
msg_ref = None
rendered_images = []
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(result_message, list):
if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(result_message)
if isinstance(result_message, list):
+3 -2
View File
@@ -13,6 +13,7 @@ class CommandRegisterRequest():
description: str
priority: int
handler: Callable
use_regex: bool = False
plugin_name: str = None
class PluginCommandBridge():
@@ -20,6 +21,6 @@ class PluginCommandBridge():
self.plugin_commands_waitlist: List[CommandRegisterRequest] = []
self.cached_plugins = cached_plugins
def register_command(self, plugin_name, command_name, description, priority, handler):
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, plugin_name))
def register_command(self, plugin_name, command_name, description, priority, handler, use_regex=False):
self.plugin_commands_waitlist.append(CommandRegisterRequest(command_name, description, priority, handler, use_regex, plugin_name))
+7 -3
View File
@@ -123,7 +123,7 @@ class PluginManager():
return p
def uninstall_plugin(self, plugin_name: str):
plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins)
plugin = self.get_registered_plugin(plugin_name)
if not plugin:
raise Exception("插件不存在。")
root_dir_name = plugin.root_dir_name
@@ -133,7 +133,7 @@ class PluginManager():
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
def update_plugin(self, plugin_name: str):
plugin = self.get_registered_plugin(plugin_name, self.context.cached_plugins)
plugin = self.get_registered_plugin(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -155,6 +155,8 @@ class PluginManager():
p = plugin['module']
module_path = plugin['module_path']
root_dir_name = plugin['pname']
logger.info(f"正在加载插件 {root_dir_name} ...")
# self.check_plugin_dept_update(cached_plugins, root_dir_name)
@@ -166,8 +168,10 @@ class PluginManager():
try:
# 尝试传入 ctx
obj = getattr(module, cls[0])(context=self.context)
except:
except TypeError:
obj = getattr(module, cls[0])()
except BaseException as e:
raise e
metadata = None
+6 -4
View File
@@ -53,7 +53,7 @@ class ProviderOpenAIOfficial(Provider):
os.makedirs("data/openai", exist_ok=True)
self.cc = CmdConfig
self.context = context
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
@@ -78,7 +78,7 @@ class ProviderOpenAIOfficial(Provider):
)
self.model_configs: Dict = cfg['chatGPTConfigs']
super().set_curr_model(self.model_configs['model'])
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
@@ -385,7 +385,9 @@ class ProviderOpenAIOfficial(Provider):
assert isinstance(completion, ChatCompletion)
logger.debug(f"openai completion: {completion.usage}")
if len(completion.choices) == 0:
raise Exception("OpenAI API 返回的 completion 为空。")
choice = completion.choices[0]
usage_tokens = completion.usage.total_tokens
@@ -492,7 +494,7 @@ class ProviderOpenAIOfficial(Provider):
def set_model(self, model: str):
self.model_configs['model'] = model
self.cc.put_by_dot_str("openai.chatGPTConfigs.model", model)
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
super().set_curr_model(model)
def get_configs(self):
+75 -1
View File
@@ -1 +1,75 @@
VERSION = '3.3.4'
VERSION = '3.3.7'
DEFAULT_CONFIG = {
"qqbot": {
"enable": False,
"appid": "",
"token": "",
},
"gocqbot": {
"enable": False,
},
"uniqueSessionMode": False,
"dump_history_interval": 10,
"limit": {
"time": 60,
"count": 30,
},
"notice": "",
"direct_message_mode": True,
"reply_prefix": "",
"baidu_aip": {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": ""
},
"openai": {
"key": [],
"api_base": "",
"chatGPTConfigs": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"total_tokens_limit": 10000,
},
"qq_forward_threshold": 200,
"qq_welcome": "",
"qq_pic_mode": True,
"gocq_host": "127.0.0.1",
"gocq_http_port": 5700,
"gocq_websocket_port": 6700,
"gocq_react_group": True,
"gocq_react_guild": True,
"gocq_react_friend": True,
"gocq_react_group_increase": True,
"other_admins": [],
"CHATGPT_BASE_URL": "",
"qqbot_secret": "",
"qqofficial_enable_group_message": False,
"admin_qq": "",
"nick_qq": ["/", "!"],
"admin_qqchan": "",
"llm_env_prompt": "",
"llm_wake_prefix": "",
"default_personality_str": "",
"openai_image_generate": {
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
},
"http_proxy": "",
"https_proxy": "",
"dashboard_username": "",
"dashboard_password": "",
"aiocqhttp": {
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 0,
}
}
+15 -12
View File
@@ -28,37 +28,40 @@ class Context:
self.unique_session = False # 独立会话
self.version: str = None # 机器人版本
self.nick = None # gocq 的唤醒词
self.stat = {}
self.nick: tuple = None # gocq 的唤醒词
self.t2i_mode = False
self.web_search = False # 是否开启了网页搜索
self.reply_prefix = ""
self.metrics_uploader = None
self.updator: AstrBotUpdator = None
self.plugin_updator: PluginUpdator = None
self.metrics_uploader = None
self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins)
self.image_renderer = TextToImageRenderer()
self.image_uploader = ImageUploader()
self.message_handler = None # see astrbot/message/handler.py
self.ext_tasks: List[Task] = []
# useless
self.reply_prefix = ""
def register_commands(self,
plugin_name: str,
command_name: str,
description: str,
priority: int,
handler: callable):
handler: callable,
use_regex: bool = False):
'''
注册插件指令。
`plugin_name`: 插件名,注意需要和你的 metadata 中的一致。
`command_name`: 指令名,如 "help"。不需要带前缀。
`description`: 指令描述。
`priority`: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
`handler`: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
@param plugin_name: 插件名,注意需要和你的 metadata 中的一致。
@param command_name: 指令名,如 "help"。不需要带前缀。
@param description: 指令描述。
@param priority: 优先级越高,越先被处理。合理的优先级应该在 1-10 之间。
@param handler: 指令处理函数。函数参数:message: AstrMessageEvent, context: Context
@param use_regex: 是否使用正则表达式匹配指令名。
'''
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler)
self.plugin_command_bridge.register_command(plugin_name, command_name, description, priority, handler, use_regex)
def register_task(self, coro: Awaitable, task_name: str):
'''
+38 -29
View File
@@ -1,19 +1,31 @@
import os
import json
from typing import Union
from type.config import DEFAULT_CONFIG
cpath = "data/cmd_config.json"
def check_exist():
if not os.path.exists(cpath):
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump({}, f, indent=4, ensure_ascii=False)
json.dump({}, f, ensure_ascii=False)
f.flush()
class CmdConfig():
def __init__(self) -> None:
self.cached_config: dict = {}
self.init_configs()
def init_configs(self):
'''
初始化必需的配置项
'''
self.init_config_items(DEFAULT_CONFIG)
@staticmethod
def get(key, default=None):
'''
从文件系统中直接获取配置
'''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
d = json.load(f)
@@ -22,28 +34,33 @@ class CmdConfig():
else:
return default
@staticmethod
def get_all():
def get_all(self):
'''
从文件系统中获取所有配置
'''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
return json.load(f)
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
return conf
@staticmethod
def put(key, value):
check_exist()
def put(self, key, value):
with open(cpath, "r", encoding="utf-8-sig") as f:
d = json.load(f)
d[key] = value
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=4, ensure_ascii=False)
json.dump(d, f, indent=2, ensure_ascii=False)
f.flush()
self.cached_config[key] = value
@staticmethod
def put_by_dot_str(key: str, value):
'''
根据点分割的字符串,将值写入配置文件
'''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
d = json.load(f)
_d = d
@@ -54,30 +71,22 @@ class CmdConfig():
else:
_d = _d[_ks[i]]
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=4, ensure_ascii=False)
json.dump(d, f, indent=2, ensure_ascii=False)
f.flush()
@staticmethod
def init_attributes(key: Union[str, list], init_val=""):
check_exist()
conf_str = ''
with open(cpath, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'):
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
d = json.loads(conf_str)
def init_config_items(self, d: dict):
conf = self.get_all()
if not self.cached_config:
self.cached_config = conf
_tag = False
if isinstance(key, str):
if key not in d:
d[key] = init_val
for key, val in d.items():
if key not in conf:
conf[key] = val
_tag = True
elif isinstance(key, list):
for k in key:
if k not in d:
d[k] = init_val
_tag = True
if _tag:
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=4, ensure_ascii=False)
json.dump(conf, f, indent=2, ensure_ascii=False)
f.flush()
+1 -132
View File
@@ -1,89 +1,5 @@
import json, os
from util.cmd_config import CmdConfig
from type.config import VERSION
from type.types import Context
def init_configs():
'''
初始化必需的配置项
'''
cc = CmdConfig()
cc.init_attributes("qqbot", {
"enable": False,
"appid": "",
"token": "",
})
cc.init_attributes("gocqbot", {
"enable": False,
})
cc.init_attributes("uniqueSessionMode", False)
cc.init_attributes("dump_history_interval", 10)
cc.init_attributes("limit", {
"time": 60,
"count": 30,
})
cc.init_attributes("notice", "")
cc.init_attributes("direct_message_mode", True)
cc.init_attributes("reply_prefix", "")
cc.init_attributes("baidu_aip", {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": ""
})
cc.init_attributes("openai", {
"key": [],
"api_base": "",
"chatGPTConfigs": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"total_tokens_limit": 10000,
})
cc.init_attributes("qq_forward_threshold", 200)
cc.init_attributes("qq_welcome", "")
cc.init_attributes("qq_pic_mode", True)
cc.init_attributes("gocq_host", "127.0.0.1")
cc.init_attributes("gocq_http_port", 5700)
cc.init_attributes("gocq_websocket_port", 6700)
cc.init_attributes("gocq_react_group", True)
cc.init_attributes("gocq_react_guild", True)
cc.init_attributes("gocq_react_friend", True)
cc.init_attributes("gocq_react_group_increase", True)
cc.init_attributes("other_admins", [])
cc.init_attributes("CHATGPT_BASE_URL", "")
cc.init_attributes("qqbot_secret", "")
cc.init_attributes("qqofficial_enable_group_message", False)
cc.init_attributes("admin_qq", "")
cc.init_attributes("nick_qq", ["!", "", "ai"])
cc.init_attributes("admin_qqchan", "")
cc.init_attributes("llm_env_prompt", "")
cc.init_attributes("llm_wake_prefix", "")
cc.init_attributes("default_personality_str", "")
cc.init_attributes("openai_image_generate", {
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
})
cc.init_attributes("http_proxy", "")
cc.init_attributes("https_proxy", "")
cc.init_attributes("dashboard_username", "")
cc.init_attributes("dashboard_password", "")
# aiocqhttp 适配器
cc.init_attributes("aiocqhttp", {
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 0,
})
def try_migrate_config():
'''
@@ -97,51 +13,4 @@ def try_migrate_config():
try:
os.remove("cmd_config.json")
except Exception as e:
pass
def inject_to_context(context: Context):
'''
将配置注入到 Context 中。
this method returns all the configs
'''
cc = CmdConfig()
context.version = VERSION
context.base_config = cc.get_all()
cfg = context.base_config
if 'reply_prefix' in cfg:
# 适配旧版配置
if isinstance(cfg['reply_prefix'], dict):
context.reply_prefix = ""
cfg['reply_prefix'] = ""
cc.put("reply_prefix", "")
else:
context.reply_prefix = cfg['reply_prefix']
default_personality_str = cc.get("default_personality_str", "")
if default_personality_str == "":
context.default_personality = None
else:
context.default_personality = {
"name": "default",
"prompt": default_personality_str,
}
if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']:
context.unique_session = True
else:
context.unique_session = False
nick_qq = cc.get("nick_qq", None)
if nick_qq == None:
nick_qq = ("/", )
if isinstance(nick_qq, str):
nick_qq = (nick_qq, )
if isinstance(nick_qq, list):
nick_qq = tuple(nick_qq)
context.nick = nick_qq
context.t2i_mode = cc.get("qq_pic_mode", True)
return cfg
pass
+8 -1
View File
@@ -36,7 +36,14 @@ class MetricUploader():
for plugin in context.cached_plugins:
self.plugin_stats[plugin.metadata.plugin_name] = {
"metadata": plugin.metadata
"metadata": {
"plugin_name": plugin.metadata.plugin_name,
"plugin_type": plugin.metadata.plugin_type.value,
"author": plugin.metadata.author,
"desc": plugin.metadata.desc,
"version": plugin.metadata.version,
"repo": plugin.metadata.repo,
}
}
try:
+5 -1
View File
@@ -35,7 +35,11 @@ class AstrBotUpdator(RepoZipUpdator):
py = sys.executable
self.terminate_child_processes()
py = py.replace(" ", "\\ ")
os.execl(py, py, *sys.argv)
try:
os.execl(py, py, *sys.argv)
except Exception as e:
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
raise e
def check_update(self, url: str, current_version: str) -> ReleaseInfo:
return super().check_update(self.ASTRBOT_RELEASE_API, VERSION)