fix: 优化初始化、消息处理时的配置读取过程,减少性能损耗
This commit is contained in:
+19
-7
@@ -10,6 +10,7 @@ from model.plugin.manager import PluginManager
|
||||
from model.platform.manager import PlatformManager
|
||||
from typing import Dict, List, Union
|
||||
from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from util.cmd_config import CmdConfig
|
||||
@@ -23,15 +24,26 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
self.config_helper: CmdConfig = CmdConfig()
|
||||
self.config_helper = CmdConfig()
|
||||
|
||||
# load configs and ensure the backward compatibility
|
||||
init_configs()
|
||||
try_migrate_config()
|
||||
self.configs = inject_to_context(self.context)
|
||||
logger.info("AstrBot v" + self.context.version)
|
||||
self.context.config_helper = self.config_helper
|
||||
self.context.base_config = self.config_helper.cached_config
|
||||
|
||||
self.context.default_personality = {
|
||||
"name": "default",
|
||||
"prompt": self.context.base_config.get("default_personality_str", ""),
|
||||
}
|
||||
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
|
||||
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
|
||||
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
|
||||
self.context.nick = nick_qq
|
||||
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
|
||||
self.context.version = VERSION
|
||||
|
||||
logger.info("AstrBot v" + self.context.version)
|
||||
|
||||
# apply proxy settings
|
||||
http_proxy = self.context.base_config.get("http_proxy")
|
||||
https_proxy = self.context.base_config.get("https_proxy")
|
||||
@@ -93,9 +105,9 @@ class AstrBotBootstrap():
|
||||
await asyncio.sleep(5)
|
||||
|
||||
def load_llm(self):
|
||||
if 'openai' in self.configs and \
|
||||
len(self.configs['openai']['key']) and \
|
||||
self.configs['openai']['key'][0] is not None:
|
||||
if 'openai' in self.config_helper.cached_config and \
|
||||
len(self.config_helper.cached_config['openai']['key']) and \
|
||||
self.config_helper.cached_config['openai']['key'][0] is not None:
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
|
||||
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
|
||||
|
||||
@@ -114,7 +114,7 @@ class MessageHandler():
|
||||
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
|
||||
self.nicks = self.context.nick
|
||||
self.provider = provider
|
||||
self.reply_prefix = self.context.reply_prefix
|
||||
self.reply_prefix = str(self.context.reply_prefix)
|
||||
|
||||
def set_provider(self, provider: Provider):
|
||||
self.provider = provider
|
||||
|
||||
+1
-2
@@ -86,12 +86,11 @@ class AstrBotDashBoard():
|
||||
|
||||
@self.dashboard_be.post("/api/change_password")
|
||||
def change_password():
|
||||
password = self.context.base_config("dashboard_password", "")
|
||||
password = self.context.base_config.get("dashboard_password", "")
|
||||
# 获得请求体
|
||||
post_data = request.json
|
||||
if post_data["password"] == password:
|
||||
self.context.config_helper.put("dashboard_password", post_data["new_password"])
|
||||
self.context.base_config['dashboard_password'] = post_data["new_password"]
|
||||
return Response(
|
||||
status="success",
|
||||
message="修改成功。",
|
||||
|
||||
@@ -230,15 +230,17 @@ class InternalCommandHandler:
|
||||
)
|
||||
|
||||
def t2i_toggle(self, message: AstrMessageEvent, context: Context):
|
||||
p = context.config_helper.get("qq_pic_mode", True)
|
||||
p = context.t2i_mode
|
||||
if p:
|
||||
context.config_helper.put("qq_pic_mode", False)
|
||||
context.t2i_mode = False
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain="已关闭文本转图片模式。",
|
||||
)
|
||||
context.config_helper.put("qq_pic_mode", True)
|
||||
context.t2i_mode = True
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
|
||||
@@ -117,8 +117,8 @@ class AIOCQHTTP(Platform):
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.sender.user_id)
|
||||
if sender_id == self.context.config_helper.get('admin_qq', '') or \
|
||||
sender_id in self.context.config_helper.get('other_admins', []):
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -154,7 +154,7 @@ class AIOCQHTTP(Platform):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
|
||||
@@ -112,8 +112,8 @@ class QQGOCQ(Platform):
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.raw_message.user_id)
|
||||
if sender_id == self.context.config_helper.get('admin_qq', '') or \
|
||||
sender_id in self.context.config_helper.get('other_admins', []):
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -152,7 +152,7 @@ class QQGOCQ(Platform):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
@@ -186,7 +186,7 @@ class QQGOCQ(Platform):
|
||||
plain_text_len += len(i.text)
|
||||
elif isinstance(i, Image):
|
||||
image_num += 1
|
||||
if plain_text_len > self.context.config_helper.get('qq_forward_threshold', 200):
|
||||
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
|
||||
# 删除At
|
||||
for i in message_chain:
|
||||
if isinstance(i, At):
|
||||
|
||||
@@ -209,8 +209,8 @@ class QQOfficial(Platform):
|
||||
|
||||
# 解析出 role
|
||||
sender_id = message.sender.user_id
|
||||
if sender_id == self.context.config_helper.get('admin_qqchan', None) or \
|
||||
sender_id in self.context.config_helper.get('other_admins', None):
|
||||
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
|
||||
sender_id in self.context.base_config.get('other_admins', None):
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -249,7 +249,7 @@ class QQOfficial(Platform):
|
||||
msg_ref = None
|
||||
rendered_images = []
|
||||
|
||||
if self.context.config_helper.get("qq_pic_mode", False) and isinstance(result_message, list):
|
||||
if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(result_message)
|
||||
|
||||
if isinstance(result_message, list):
|
||||
|
||||
@@ -53,7 +53,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
os.makedirs("data/openai", exist_ok=True)
|
||||
|
||||
self.cc = CmdConfig
|
||||
self.context = context
|
||||
self.key_data_path = "data/openai/keys.json"
|
||||
self.api_keys = []
|
||||
self.chosen_api_key = None
|
||||
@@ -78,7 +78,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
)
|
||||
self.model_configs: Dict = cfg['chatGPTConfigs']
|
||||
super().set_curr_model(self.model_configs['model'])
|
||||
self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None)
|
||||
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
|
||||
self.session_memory: Dict[str, List] = {} # 会话记忆
|
||||
self.session_memory_lock = threading.Lock()
|
||||
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
|
||||
@@ -492,7 +492,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
def set_model(self, model: str):
|
||||
self.model_configs['model'] = model
|
||||
self.cc.put_by_dot_str("openai.chatGPTConfigs.model", model)
|
||||
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
|
||||
super().set_curr_model(model)
|
||||
|
||||
def get_configs(self):
|
||||
|
||||
+75
-1
@@ -1 +1,75 @@
|
||||
VERSION = '3.3.5'
|
||||
VERSION = '3.3.7'
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"qqbot": {
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"token": "",
|
||||
},
|
||||
"gocqbot": {
|
||||
"enable": False,
|
||||
},
|
||||
"uniqueSessionMode": False,
|
||||
"dump_history_interval": 10,
|
||||
"limit": {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
},
|
||||
"notice": "",
|
||||
"direct_message_mode": True,
|
||||
"reply_prefix": "",
|
||||
"baidu_aip": {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
},
|
||||
"openai": {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"chatGPTConfigs": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
},
|
||||
"total_tokens_limit": 10000,
|
||||
},
|
||||
"qq_forward_threshold": 200,
|
||||
"qq_welcome": "",
|
||||
"qq_pic_mode": True,
|
||||
"gocq_host": "127.0.0.1",
|
||||
"gocq_http_port": 5700,
|
||||
"gocq_websocket_port": 6700,
|
||||
"gocq_react_group": True,
|
||||
"gocq_react_guild": True,
|
||||
"gocq_react_friend": True,
|
||||
"gocq_react_group_increase": True,
|
||||
"other_admins": [],
|
||||
"CHATGPT_BASE_URL": "",
|
||||
"qqbot_secret": "",
|
||||
"qqofficial_enable_group_message": False,
|
||||
"admin_qq": "",
|
||||
"nick_qq": ["/", "!"],
|
||||
"admin_qqchan": "",
|
||||
"llm_env_prompt": "",
|
||||
"llm_wake_prefix": "",
|
||||
"default_personality_str": "",
|
||||
"openai_image_generate": {
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
},
|
||||
"http_proxy": "",
|
||||
"https_proxy": "",
|
||||
"dashboard_username": "",
|
||||
"dashboard_password": "",
|
||||
"aiocqhttp": {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 0,
|
||||
}
|
||||
}
|
||||
+6
-5
@@ -28,21 +28,22 @@ class Context:
|
||||
|
||||
self.unique_session = False # 独立会话
|
||||
self.version: str = None # 机器人版本
|
||||
self.nick = None # gocq 的唤醒词
|
||||
self.stat = {}
|
||||
self.nick: tuple = None # gocq 的唤醒词
|
||||
self.t2i_mode = False
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
self.reply_prefix = ""
|
||||
|
||||
self.metrics_uploader = None
|
||||
self.updator: AstrBotUpdator = None
|
||||
self.plugin_updator: PluginUpdator = None
|
||||
self.metrics_uploader = None
|
||||
|
||||
self.plugin_command_bridge = PluginCommandBridge(self.cached_plugins)
|
||||
self.image_renderer = TextToImageRenderer()
|
||||
self.image_uploader = ImageUploader()
|
||||
self.message_handler = None # see astrbot/message/handler.py
|
||||
self.ext_tasks: List[Task] = []
|
||||
|
||||
# useless
|
||||
self.reply_prefix = ""
|
||||
|
||||
def register_commands(self,
|
||||
plugin_name: str,
|
||||
command_name: str,
|
||||
|
||||
+38
-29
@@ -1,19 +1,31 @@
|
||||
import os
|
||||
import json
|
||||
from typing import Union
|
||||
from type.config import DEFAULT_CONFIG
|
||||
|
||||
cpath = "data/cmd_config.json"
|
||||
|
||||
def check_exist():
|
||||
if not os.path.exists(cpath):
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump({}, f, indent=4, ensure_ascii=False)
|
||||
json.dump({}, f, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
class CmdConfig():
|
||||
def __init__(self) -> None:
|
||||
self.cached_config: dict = {}
|
||||
self.init_configs()
|
||||
|
||||
def init_configs(self):
|
||||
'''
|
||||
初始化必需的配置项
|
||||
'''
|
||||
self.init_config_items(DEFAULT_CONFIG)
|
||||
|
||||
@staticmethod
|
||||
def get(key, default=None):
|
||||
'''
|
||||
从文件系统中直接获取配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
@@ -22,28 +34,33 @@ class CmdConfig():
|
||||
else:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def get_all():
|
||||
def get_all(self):
|
||||
'''
|
||||
从文件系统中获取所有配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
return json.load(f)
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
conf = json.loads(conf_str)
|
||||
return conf
|
||||
|
||||
@staticmethod
|
||||
def put(key, value):
|
||||
check_exist()
|
||||
def put(self, key, value):
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
d[key] = value
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=4, ensure_ascii=False)
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
self.cached_config[key] = value
|
||||
|
||||
@staticmethod
|
||||
def put_by_dot_str(key: str, value):
|
||||
'''
|
||||
根据点分割的字符串,将值写入配置文件
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
_d = d
|
||||
@@ -54,30 +71,22 @@ class CmdConfig():
|
||||
else:
|
||||
_d = _d[_ks[i]]
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=4, ensure_ascii=False)
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
@staticmethod
|
||||
def init_attributes(key: Union[str, list], init_val=""):
|
||||
check_exist()
|
||||
conf_str = ''
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'):
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
d = json.loads(conf_str)
|
||||
def init_config_items(self, d: dict):
|
||||
conf = self.get_all()
|
||||
|
||||
if not self.cached_config:
|
||||
self.cached_config = conf
|
||||
|
||||
_tag = False
|
||||
|
||||
if isinstance(key, str):
|
||||
if key not in d:
|
||||
d[key] = init_val
|
||||
for key, val in d.items():
|
||||
if key not in conf:
|
||||
conf[key] = val
|
||||
_tag = True
|
||||
elif isinstance(key, list):
|
||||
for k in key:
|
||||
if k not in d:
|
||||
d[k] = init_val
|
||||
_tag = True
|
||||
if _tag:
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=4, ensure_ascii=False)
|
||||
json.dump(conf, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
+1
-132
@@ -1,89 +1,5 @@
|
||||
import json, os
|
||||
from util.cmd_config import CmdConfig
|
||||
from type.config import VERSION
|
||||
from type.types import Context
|
||||
|
||||
def init_configs():
|
||||
'''
|
||||
初始化必需的配置项
|
||||
'''
|
||||
cc = CmdConfig()
|
||||
|
||||
cc.init_attributes("qqbot", {
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"token": "",
|
||||
})
|
||||
cc.init_attributes("gocqbot", {
|
||||
"enable": False,
|
||||
})
|
||||
cc.init_attributes("uniqueSessionMode", False)
|
||||
cc.init_attributes("dump_history_interval", 10)
|
||||
cc.init_attributes("limit", {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
})
|
||||
cc.init_attributes("notice", "")
|
||||
cc.init_attributes("direct_message_mode", True)
|
||||
cc.init_attributes("reply_prefix", "")
|
||||
cc.init_attributes("baidu_aip", {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": ""
|
||||
})
|
||||
cc.init_attributes("openai", {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"chatGPTConfigs": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
},
|
||||
"total_tokens_limit": 10000,
|
||||
})
|
||||
|
||||
|
||||
cc.init_attributes("qq_forward_threshold", 200)
|
||||
cc.init_attributes("qq_welcome", "")
|
||||
cc.init_attributes("qq_pic_mode", True)
|
||||
cc.init_attributes("gocq_host", "127.0.0.1")
|
||||
cc.init_attributes("gocq_http_port", 5700)
|
||||
cc.init_attributes("gocq_websocket_port", 6700)
|
||||
cc.init_attributes("gocq_react_group", True)
|
||||
cc.init_attributes("gocq_react_guild", True)
|
||||
cc.init_attributes("gocq_react_friend", True)
|
||||
cc.init_attributes("gocq_react_group_increase", True)
|
||||
cc.init_attributes("other_admins", [])
|
||||
cc.init_attributes("CHATGPT_BASE_URL", "")
|
||||
cc.init_attributes("qqbot_secret", "")
|
||||
cc.init_attributes("qqofficial_enable_group_message", False)
|
||||
cc.init_attributes("admin_qq", "")
|
||||
cc.init_attributes("nick_qq", ["!", "!", "ai"])
|
||||
cc.init_attributes("admin_qqchan", "")
|
||||
cc.init_attributes("llm_env_prompt", "")
|
||||
cc.init_attributes("llm_wake_prefix", "")
|
||||
cc.init_attributes("default_personality_str", "")
|
||||
cc.init_attributes("openai_image_generate", {
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
})
|
||||
cc.init_attributes("http_proxy", "")
|
||||
cc.init_attributes("https_proxy", "")
|
||||
cc.init_attributes("dashboard_username", "")
|
||||
cc.init_attributes("dashboard_password", "")
|
||||
|
||||
# aiocqhttp 适配器
|
||||
cc.init_attributes("aiocqhttp", {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 0,
|
||||
})
|
||||
|
||||
def try_migrate_config():
|
||||
'''
|
||||
@@ -97,51 +13,4 @@ def try_migrate_config():
|
||||
try:
|
||||
os.remove("cmd_config.json")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def inject_to_context(context: Context):
|
||||
'''
|
||||
将配置注入到 Context 中。
|
||||
this method returns all the configs
|
||||
'''
|
||||
cc = CmdConfig()
|
||||
|
||||
context.version = VERSION
|
||||
context.base_config = cc.get_all()
|
||||
|
||||
cfg = context.base_config
|
||||
|
||||
if 'reply_prefix' in cfg:
|
||||
# 适配旧版配置
|
||||
if isinstance(cfg['reply_prefix'], dict):
|
||||
context.reply_prefix = ""
|
||||
cfg['reply_prefix'] = ""
|
||||
cc.put("reply_prefix", "")
|
||||
else:
|
||||
context.reply_prefix = cfg['reply_prefix']
|
||||
|
||||
default_personality_str = cc.get("default_personality_str", "")
|
||||
if default_personality_str == "":
|
||||
context.default_personality = None
|
||||
else:
|
||||
context.default_personality = {
|
||||
"name": "default",
|
||||
"prompt": default_personality_str,
|
||||
}
|
||||
|
||||
if 'uniqueSessionMode' in cfg and cfg['uniqueSessionMode']:
|
||||
context.unique_session = True
|
||||
else:
|
||||
context.unique_session = False
|
||||
|
||||
nick_qq = cc.get("nick_qq", None)
|
||||
if nick_qq == None:
|
||||
nick_qq = ("/", )
|
||||
if isinstance(nick_qq, str):
|
||||
nick_qq = (nick_qq, )
|
||||
if isinstance(nick_qq, list):
|
||||
nick_qq = tuple(nick_qq)
|
||||
context.nick = nick_qq
|
||||
context.t2i_mode = cc.get("qq_pic_mode", True)
|
||||
|
||||
return cfg
|
||||
pass
|
||||
Reference in New Issue
Block a user