diff --git a/.coverage b/.coverage new file mode 100644 index 000000000..439ed81bb Binary files /dev/null and b/.coverage differ diff --git a/dashboard/server.py b/dashboard/server.py index f1675b1b2..87e9d25c4 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -383,14 +383,14 @@ class AstrBotDashBoard(): def save_astrbot_configs(self, post_configs: dict): try: self.dashboard_helper.save_astrbot_config(post_configs) - threading.Thread(target=self.astrbot_updator._reboot, args=(3, ), daemon=True).start() + threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start() except Exception as e: raise e def save_extension_configs(self, post_configs: dict): try: self.dashboard_helper.save_extension_config(post_configs) - threading.Thread(target=self.astrbot_updator._reboot, args=(3, ), daemon=True).start() + threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start() except Exception as e: raise e diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 8245f984f..e4e120f79 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -8,7 +8,6 @@ from type.types import Context from type.config import VERSION from SparkleLogging.utils.core import LogManager from logging import Logger -from nakuru.entities.components import Image from util.agent.web_searcher import search_from_bing, fetch_website_content logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -208,10 +207,11 @@ class InternalCommandHandler: return CommandResult( hit=True, success=True, - message_chain=f"网页搜索功能当前状态: {context.web_search}", + message_chain=f"网页搜索功能当前状态: {context.config_helper.llm_settings.web_search}", ) elif l[1] == 'on': - context.web_search = True + context.config_helper.llm_settings.web_search = True + context.config_helper.save_config() context.register_llm_tool("web_search", [{ "type": "string", "name": "keyword", @@ -235,7 +235,8 @@ class InternalCommandHandler: message_chain="已开启网页搜索", ) elif l[1] == 'off': - context.web_search = False + context.config_helper.llm_settings.web_search = False + context.config_helper.save_config() context.unregister_llm_tool("web_search") context.unregister_llm_tool("fetch_website_content") @@ -254,15 +255,15 @@ class InternalCommandHandler: def t2i_toggle(self, message: AstrMessageEvent, context: Context): p = context.config_helper.t2i if p: - context.config_helper.put("qq_pic_mode", False) context.config_helper.t2i = False + context.config_helper.save_config() return CommandResult( hit=True, success=True, message_chain="已关闭文本转图片模式。", ) - context.config_helper.put("qq_pic_mode", True) context.config_helper.t2i = True + context.config_helper.save_config() return CommandResult( hit=True, diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index b79f28ff8..c6f4164f8 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -188,7 +188,7 @@ class AIOCQHTTP(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): + if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index 3edb69d7a..9b06526b6 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -185,7 +185,7 @@ class QQNakuru(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): + if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index e6716532f..f3b97dcb4 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -243,7 +243,7 @@ class QQOfficial(Platform): msg_ref = None rendered_images = [] - if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list): + if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list): rendered_images = await self.convert_to_t2i_chain(result_message) if isinstance(result_message, list): diff --git a/tests/test_message.py b/tests/test_message.py index a5fc4578a..809b87164 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -13,6 +13,8 @@ from type.message_event import * from SparkleLogging.utils.core import LogManager from logging import Formatter +from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig + logger = LogManager.GetLogger( log_name='astrbot', out_to_console=True, @@ -24,8 +26,11 @@ os.environ['TEST_MODE'] = 'on' bootstrap = AstrBotBootstrap() asyncio.run(bootstrap.run()) -qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler) -aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler) +for p_config in bootstrap.context.config_helper.platform: + if isinstance(p_config, QQOfficialPlatformConfig): + qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler, p_config) + elif isinstance(p_config, AiocqhttpPlatformConfig): + aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler, p_config) class TestBasicMessageHandle(): @pytest.mark.asyncio diff --git a/type/types.py b/type/types.py index a0dace291..a29e15045 100644 --- a/type/types.py +++ b/type/types.py @@ -21,8 +21,8 @@ class Context: ''' def __init__(self): + self.running = True self.logger: Logger = None - self.base_config: dict = None # 配置(期望启动机器人后是不变的) self.config_helper: AstrBotConfig = None self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件 self.platforms: List[RegisteredPlatform] = [] diff --git a/util/cmd_config.py b/util/cmd_config.py index befc5e9b8..cf85bdb25 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -202,6 +202,10 @@ class AstrBotConfig(): json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False) f.flush() + def save_config(self): + '''将现存配置写入文件''' + self.flush_config(self.to_dict()) + def init_configs(self): '''初始化必需的配置项''' config = None @@ -228,7 +232,7 @@ class AstrBotConfig(): self.load_from_dict(config) - def get(self, key, default=None): + def get(self, key: str, default=None): ''' 从文件系统中直接获取配置 ''' @@ -288,6 +292,5 @@ class AstrBotConfig(): if key not in self: raise KeyError(f"Key {key} not found in config.") - def check_exist(self) -> bool: return os.path.exists(ASTRBOT_CONFIG_PATH) \ No newline at end of file