From f8aef78d259041d5a8a63d75b16b67290d9b289e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 6 Aug 2024 04:58:29 -0400 Subject: [PATCH 1/5] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E6=A0=BC=E5=BC=8F=20perf:=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E5=A4=84=E7=90=86=E8=BF=87=E7=A8=8B=E5=92=8C?= =?UTF-8?q?=E5=91=88=E7=8E=B0=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/bootstrap.py | 53 ++- astrbot/message/baidu_aip_judge.py | 17 +- astrbot/message/handler.py | 33 +- dashboard/__init__.py | 1 - dashboard/helper.py | 568 +++-------------------------- dashboard/server.py | 140 +++---- model/command/internal_handler.py | 11 +- model/platform/__init__.py | 9 +- model/platform/manager.py | 60 +-- model/platform/qq_aiocqhttp.py | 24 +- model/platform/qq_nakuru.py | 38 +- model/platform/qq_official.py | 53 ++- model/provider/openai_official.py | 51 ++- type/config.py | 277 ++++++++++++++ type/types.py | 16 +- util/cmd_config.py | 287 ++++++++++++--- util/config_utils.py | 19 +- util/metrics.py | 3 +- util/plugin_dev/api/v1/platform.py | 2 +- 19 files changed, 829 insertions(+), 833 deletions(-) diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index 0facf2037..81787c904 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -13,7 +13,7 @@ from type.types import Context from type.config import VERSION from SparkleLogging.utils.core import LogManager from logging import Logger -from util.cmd_config import CmdConfig +from util.cmd_config import AstrBotConfig from util.metrics import MetricUploader from util.config_utils import * from util.updator.astrbot_updator import AstrBotUpdator @@ -24,29 +24,15 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot') class AstrBotBootstrap(): def __init__(self) -> None: self.context = Context() - self.config_helper = CmdConfig() # load configs and ensure the backward compatibility try_migrate_config() + self.config_helper = AstrBotConfig() self.context.config_helper = self.config_helper - self.context.base_config = self.config_helper.cached_config - - self.context.default_personality = { - "name": "default", - "prompt": self.context.base_config.get("default_personality_str", ""), - } - self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False) - nick_qq = self.context.base_config.get("nick_qq", ('/', '!')) - if isinstance(nick_qq, str): nick_qq = (nick_qq, ) - self.context.nick = nick_qq - self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True) - self.context.version = VERSION - - logger.info("AstrBot v" + self.context.version) - + logger.info("AstrBot v" + VERSION) # apply proxy settings - http_proxy = self.context.base_config.get("http_proxy") - https_proxy = self.context.base_config.get("https_proxy") + http_proxy = self.context.config_helper.http_proxy + https_proxy = self.context.config_helper.https_proxy if http_proxy: os.environ['HTTP_PROXY'] = http_proxy if https_proxy: @@ -66,10 +52,9 @@ class AstrBotBootstrap(): self.db_conn_helper = dbConn() # load llm provider - self.llm_instance: Provider = None self.load_llm() - self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper, self.llm_instance) + self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper) self.platfrom_manager = PlatformManager(self.context, self.message_handler) self.dashboard = AstrBotDashBoard(self.context, plugin_manager=self.plugin_manager, astrbot_updator=self.updator) self.metrics_uploader = MetricUploader(self.context) @@ -105,16 +90,26 @@ class AstrBotBootstrap(): await asyncio.sleep(5) def load_llm(self): - if 'openai' in self.config_helper.cached_config and \ - len(self.config_helper.cached_config['openai']['key']) and \ - self.config_helper.cached_config['openai']['key'][0] is not None: - from model.provider.openai_official import ProviderOpenAIOfficial + f = False + llms = self.context.config_helper.llm + logger.info(f"加载 {len(llms)} 个 LLM Provider...") + for llm in llms: + if llm.enable: + if llm.name == "openai" and llm.key and llm.enable: + self.load_openai(llm) + f = True + logger.info(f"已启用 OpenAI API 支持。") + else: + logger.warn(f"未知的 LLM Provider: {llm.name}") + if f: from model.command.openai_official_handler import OpenAIOfficialCommandHandler self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) - self.llm_instance = ProviderOpenAIOfficial(self.context) - self.openai_command_handler.set_provider(self.llm_instance) - self.context.register_provider("internal_openai", self.llm_instance) - logger.info("已启用 OpenAI API 支持。") + self.openai_command_handler.set_provider(self.context.llms[0].llm_instance) + + def load_openai(self, llm_config): + from model.provider.openai_official import ProviderOpenAIOfficial + inst = ProviderOpenAIOfficial(llm_config) + self.context.register_provider("internal_openai", inst) def load_plugins(self): self.plugin_manager.plugin_reload() diff --git a/astrbot/message/baidu_aip_judge.py b/astrbot/message/baidu_aip_judge.py index cd2417d0a..f8d09a860 100644 --- a/astrbot/message/baidu_aip_judge.py +++ b/astrbot/message/baidu_aip_judge.py @@ -1,16 +1,15 @@ from aip import AipContentCensor +from util.cmd_config import BaiduAIPConfig class BaiduJudge: - def __init__(self, baidu_configs) -> None: - if 'app_id' in baidu_configs and 'api_key' in baidu_configs and 'secret_key' in baidu_configs: - self.app_id = str(baidu_configs['app_id']) - self.api_key = baidu_configs['api_key'] - self.secret_key = baidu_configs['secret_key'] - self.client = AipContentCensor( - self.app_id, self.api_key, self.secret_key) - else: - raise ValueError("Baidu configs error! 请填写百度内容审核服务相关配置!") + def __init__(self, baidu_configs: BaiduAIPConfig) -> None: + self.app_id = baidu_configs.app_id + self.api_key = baidu_configs.api_key + self.secret_key = baidu_configs.secret_key + self.client = AipContentCensor(self.app_id, + self.api_key, + self.secret_key) def judge(self, text): res = self.client.textCensorUserDefined(text) diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index 12cd57419..45c8cff4a 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -22,16 +22,11 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot') class RateLimitHelper(): def __init__(self, context: Context) -> None: self.user_rate_limit: Dict[int, int] = {} - self.rate_limit_time: int = 60 - self.rate_limit_count: int = 10 + rl = context.config_helper.platform_settings.rate_limit + self.rate_limit_time: int = rl.time + self.rate_limit_count: int = rl.count self.user_frequency = {} - - if 'limit' in context.base_config: - if 'count' in context.base_config['limit']: - self.rate_limit_count = context.base_config['limit']['count'] - if 'time' in context.base_config['limit']: - self.rate_limit_time = context.base_config['limit']['time'] - + def check_frequency(self, session_id: str) -> bool: ''' 检查发言频率 @@ -56,12 +51,11 @@ class RateLimitHelper(): class ContentSafetyHelper(): def __init__(self, context: Context) -> None: self.baidu_judge = None - if 'baidu_api' in context.base_config and \ - 'enable' in context.base_config['baidu_aip'] and \ - context.base_config['baidu_aip']['enable']: + aip = context.config_helper.content_safety.baidu_aip + if aip.enable: try: from astrbot.message.baidu_aip_judge import BaiduJudge - self.baidu_judge = BaiduJudge(context.base_config['baidu_aip']) + self.baidu_judge = BaiduJudge(aip) logger.info("已启用百度 AI 内容审核。") except BaseException as e: logger.error("百度 AI 内容审核初始化失败。") @@ -104,19 +98,18 @@ class ContentSafetyHelper(): class MessageHandler(): def __init__(self, context: Context, command_manager: CommandManager, - persist_manager: dbConn, - provider: Provider) -> None: + persist_manager: dbConn) -> None: self.context = context self.command_manager = command_manager self.persist_manager = persist_manager self.rate_limit_helper = RateLimitHelper(context) self.content_safety_helper = ContentSafetyHelper(context) - self.llm_wake_prefix = self.context.base_config['llm_wake_prefix'] + self.llm_wake_prefix = self.context.config_helper.llm_settings.wake_prefix if self.llm_wake_prefix: self.llm_wake_prefix = self.llm_wake_prefix.strip() - self.nicks = self.context.nick - self.provider = provider - self.reply_prefix = str(self.context.reply_prefix) + self.nicks = self.context.config_helper.wake_prefix + self.provider = self.context.llms[0] if len(self.context.llms) > 0 else None + self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix) def set_provider(self, provider: Provider): self.provider = provider @@ -176,7 +169,7 @@ class MessageHandler(): if isinstance(comp, Image): image_url = comp.url if comp.url else comp.file break - web_search = self.context.web_search + web_search = self.context.config_helper.llm_settings.web_search if not web_search and msg_plain.startswith("ws"): # leverage web search feature web_search = True diff --git a/dashboard/__init__.py b/dashboard/__init__.py index 8e4ce699e..9394342dd 100644 --- a/dashboard/__init__.py +++ b/dashboard/__init__.py @@ -2,7 +2,6 @@ from dataclasses import dataclass class DashBoardData(): stats: dict = {} - configs: dict = {} @dataclass class Response(): diff --git a/dashboard/helper.py b/dashboard/helper.py index 4a9a7041a..c9c17f7de 100644 --- a/dashboard/helper.py +++ b/dashboard/helper.py @@ -1,537 +1,59 @@ -import threading -import asyncio - from . import DashBoardData -from typing import Union, Optional -from util.cmd_config import CmdConfig -from dataclasses import dataclass +from util.cmd_config import AstrBotConfig +from dataclasses import dataclass, asdict from util.plugin_dev.api.v1.config import update_config from SparkleLogging.utils.core import LogManager from logging import Logger from type.types import Context +from type.config import CONFIG_METADATA_2 logger: Logger = LogManager.GetLogger(log_name='astrbot') -@dataclass -class DashBoardConfig(): - config_type: str - name: Optional[str] = None - description: Optional[str] = None - path: Optional[str] = None # 仅 item 才需要 - body: Optional[list['DashBoardConfig']] = None # 仅 group 才需要 - value: Optional[Union[list, dict, str, int, bool]] = None # 仅 item 才需要 - val_type: Optional[str] = None # 仅 item 才需要 - - class DashBoardHelper(): - def __init__(self, context: Context, dashboard_data: DashBoardData): - dashboard_data.configs = { - "data": [] - } + def __init__(self, context: Context): self.context = context - self.parse_default_config(dashboard_data, context.base_config) + self.config_key_dont_show = ['dashboard', 'config_version'] - # 将 config.yaml、 中的配置解析到 dashboard_data.configs 中 - def parse_default_config(self, dashboard_data: DashBoardData, config: dict): - - try: - qq_official_platform_group = DashBoardConfig( - config_type="group", - name="QQ(官方)", - description="", - body=[ - DashBoardConfig( - config_type="item", - val_type="bool", - name="启用 QQ_OFFICIAL 平台", - description="官方的接口,仅支持 QQ 频道。详见 q.qq.com", - value=config['qqbot']['enable'], - path="qqbot.enable", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="QQ机器人APPID", - description="详见 q.qq.com", - value=config['qqbot']['appid'], - path="qqbot.appid", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="QQ机器人令牌", - description="详见 q.qq.com", - value=config['qqbot']['token'], - path="qqbot.token", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="QQ机器人 Secret", - description="详见 q.qq.com", - value=config['qqbot_secret'], - path="qqbot_secret", - ), - DashBoardConfig( - config_type="item", - val_type="bool", - name="是否允许 QQ 频道私聊", - description="如果启用,机器人会响应私聊消息。", - value=config['direct_message_mode'], - path="direct_message_mode", - ), - DashBoardConfig( - config_type="item", - val_type="bool", - name="是否接收QQ群消息", - description="需要机器人有相应的群消息接收权限。在 q.qq.com 上查看。", - value=config['qqofficial_enable_group_message'], - path="qqofficial_enable_group_message", - ), - ] - ) - qq_gocq_platform_group = DashBoardConfig( - config_type="group", - name="QQ(nakuru)", - description="", - body=[ - DashBoardConfig( - config_type="item", - val_type="bool", - name="启用", - description="", - value=config['gocqbot']['enable'], - path="gocqbot.enable", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="HTTP 服务器地址", - description="", - value=config['gocq_host'], - path="gocq_host", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="HTTP 服务器端口", - description="", - value=config['gocq_http_port'], - path="gocq_http_port", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="WebSocket 服务器端口", - description="目前仅支持正向 WebSocket", - value=config['gocq_websocket_port'], - path="gocq_websocket_port", - ), - DashBoardConfig( - config_type="item", - val_type="bool", - name="是否响应群消息", - description="", - value=config['gocq_react_group'], - path="gocq_react_group", - ), - DashBoardConfig( - config_type="item", - val_type="bool", - name="是否响应私聊消息", - description="", - value=config['gocq_react_friend'], - path="gocq_react_friend", - ), - DashBoardConfig( - config_type="item", - val_type="bool", - name="是否响应群成员增加消息", - description="", - value=config['gocq_react_group_increase'], - path="gocq_react_group_increase", - ), - DashBoardConfig( - config_type="item", - val_type="bool", - name="是否响应频道消息", - description="", - value=config['gocq_react_guild'], - path="gocq_react_guild", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="转发阈值(字符数)", - description="机器人回复的消息长度超出这个值后,会被折叠成转发卡片发出以减少刷屏。", - value=config['qq_forward_threshold'], - path="qq_forward_threshold", - ), - ] - ) - - qq_aiocqhttp_platform_group = DashBoardConfig( - config_type="group", - name="QQ(aiocqhttp)", - description="", - body=[ - DashBoardConfig( - config_type="item", - val_type="bool", - name="启用", - description="", - value=config['aiocqhttp']['enable'], - path="aiocqhttp.enable", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="WebSocket 反向连接 host", - description="", - value=config['aiocqhttp']['ws_reverse_host'], - path="aiocqhttp.ws_reverse_host", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="WebSocket 反向连接 port", - description="", - value=config['aiocqhttp']['ws_reverse_port'], - path="aiocqhttp.ws_reverse_port", - ), - ] - ) - - general_platform_detail_group = DashBoardConfig( - config_type="group", - name="通用平台配置", - description="", - body=[ - DashBoardConfig( - config_type="item", - val_type="bool", - name="启动消息文字转图片", - description="启动后,机器人会将消息转换为图片发送,以降低风控风险。", - value=config['qq_pic_mode'], - path="qq_pic_mode", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="消息限制时间", - description="在此时间内,机器人不会回复同一个用户的消息。单位:秒", - value=config['limit']['time'], - path="limit.time", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="消息限制次数", - description="在上面的时间内,如果用户发送消息超过此次数,则机器人不会回复。单位:次", - value=config['limit']['count'], - path="limit.count", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="回复前缀", - description="[xxxx] 你好! 其中xxxx是你可以填写的前缀。如果为空则不显示。", - value=config['reply_prefix'], - path="reply_prefix", - ), - DashBoardConfig( - config_type="item", - val_type="list", - name="通用管理员用户 ID(支持多个管理员)。通过 !myid 指令获取。", - description="", - value=config['other_admins'], - path="other_admins", - ), - DashBoardConfig( - config_type="item", - val_type="bool", - name="独立会话", - description="是否启用独立会话模式,即 1 个用户自然账号 1 个会话。", - value=config['uniqueSessionMode'], - path="uniqueSessionMode", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="LLM 唤醒词", - description="如果不为空, 那么只有当消息以此词开头时,才会调用大语言模型进行回复。如设置为 /chat,那么只有当消息以 /chat 开头时,才会调用大语言模型进行回复。", - value=config['llm_wake_prefix'], - path="llm_wake_prefix", - ) - ] - ) - - openai_official_llm_group = DashBoardConfig( - config_type="group", - name="OpenAI 官方接口类设置", - description="", - body=[ - DashBoardConfig( - config_type="item", - val_type="list", - name="OpenAI API Key", - description="OpenAI API 的 Key。支持使用非官方但兼容的 API(第三方中转key)。", - value=config['openai']['key'], - path="openai.key", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="OpenAI API 节点地址(api base)", - description="OpenAI API 的节点地址,配合非官方 API 使用。如果不想填写,那么请填写 none", - value=config['openai']['api_base'], - path="openai.api_base", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="OpenAI model", - description="OpenAI LLM 模型。详见 https://platform.openai.com/docs/api-reference/chat", - value=config['openai']['chatGPTConfigs']['model'], - path="openai.chatGPTConfigs.model", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="OpenAI max_tokens", - description="OpenAI 最大生成长度。详见 https://platform.openai.com/docs/api-reference/chat", - value=config['openai']['chatGPTConfigs']['max_tokens'], - path="openai.chatGPTConfigs.max_tokens", - ), - DashBoardConfig( - config_type="item", - val_type="float", - name="OpenAI temperature", - description="OpenAI 温度。详见 https://platform.openai.com/docs/api-reference/chat", - value=config['openai']['chatGPTConfigs']['temperature'], - path="openai.chatGPTConfigs.temperature", - ), - DashBoardConfig( - config_type="item", - val_type="float", - name="OpenAI top_p", - description="OpenAI top_p。详见 https://platform.openai.com/docs/api-reference/chat", - value=config['openai']['chatGPTConfigs']['top_p'], - path="openai.chatGPTConfigs.top_p", - ), - DashBoardConfig( - config_type="item", - val_type="float", - name="OpenAI frequency_penalty", - description="OpenAI frequency_penalty。详见 https://platform.openai.com/docs/api-reference/chat", - value=config['openai']['chatGPTConfigs']['frequency_penalty'], - path="openai.chatGPTConfigs.frequency_penalty", - ), - DashBoardConfig( - config_type="item", - val_type="float", - name="OpenAI presence_penalty", - description="OpenAI presence_penalty。详见 https://platform.openai.com/docs/api-reference/chat", - value=config['openai']['chatGPTConfigs']['presence_penalty'], - path="openai.chatGPTConfigs.presence_penalty", - ), - DashBoardConfig( - config_type="item", - val_type="int", - name="OpenAI 总生成长度限制", - description="OpenAI 总生成长度限制。详见 https://platform.openai.com/docs/api-reference/chat", - value=config['openai']['total_tokens_limit'], - path="openai.total_tokens_limit", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="OpenAI 图像生成模型", - description="OpenAI 图像生成模型。", - value=config['openai_image_generate']['model'], - path="openai_image_generate.model", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="OpenAI 图像生成大小", - description="OpenAI 图像生成大小。", - value=config['openai_image_generate']['size'], - path="openai_image_generate.size", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="OpenAI 图像生成风格", - description="OpenAI 图像生成风格。修改前请参考 OpenAI 官方文档", - value=config['openai_image_generate']['style'], - path="openai_image_generate.style", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="OpenAI 图像生成质量", - description="OpenAI 图像生成质量。修改前请参考 OpenAI 官方文档", - value=config['openai_image_generate']['quality'], - path="openai_image_generate.quality", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="问题题首提示词", - description="如果填写了此项,在每个对大语言模型的请求中,都会在问题前加上此提示词。", - value=config['llm_env_prompt'], - path="llm_env_prompt", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="默认人格文本", - description="默认人格文本", - value=config['default_personality_str'], - path="default_personality_str", - ), - ] - ) - - baidu_aip_group = DashBoardConfig( - config_type="group", - name="百度内容审核", - description="需要去申请", - body=[ - DashBoardConfig( - config_type="item", - val_type="bool", - name="启动百度内容审核服务", - description="", - value=config['baidu_aip']['enable'], - path="baidu_aip.enable" - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="APP ID", - description="", - value=config['baidu_aip']['app_id'], - path="baidu_aip.app_id" - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="API KEY", - description="", - value=config['baidu_aip']['api_key'], - path="baidu_aip.api_key" - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="SECRET KEY", - description="", - value=config['baidu_aip']['secret_key'], - path="baidu_aip.secret_key" - ) - ] - ) - - other_group = DashBoardConfig( - config_type="group", - name="其他配置", - description="其他配置描述", - body=[ - DashBoardConfig( - config_type="item", - val_type="str", - name="HTTP 代理地址", - description="建议上下一致", - value=config['http_proxy'], - path="http_proxy", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="HTTPS 代理地址", - description="建议上下一致", - value=config['https_proxy'], - path="https_proxy", - ), - DashBoardConfig( - config_type="item", - val_type="str", - name="面板用户名", - description="是的,就是你理解的这个面板的用户名", - value=config['dashboard_username'], - path="dashboard_username", - ), - ] - ) - - dashboard_data.configs['data'] = [ - qq_official_platform_group, - qq_gocq_platform_group, - general_platform_detail_group, - openai_official_llm_group, - other_group, - baidu_aip_group, - qq_aiocqhttp_platform_group - ] - - except Exception as e: - logger.error(f"配置文件解析错误:{e}") - raise e - - def save_config(self, post_config: list, namespace: str): - ''' - 根据 path 解析并保存配置 - ''' - - queue = post_config - while len(queue) > 0: - config = queue.pop(0) - if config['config_type'] == "group": - for item in config['body']: - queue.append(item) - elif config['config_type'] == "item": - if config['path'] is None or config['path'] == "": + def validate_config(self, data): + errors = [] + # 递归验证数据 + def validate(data, path=""): + for key, meta in CONFIG_METADATA_2.items(): + if key not in data: + if key not in self.config_key_dont_show: + # 这些key不会传给前端,所以不需要验证 + errors.append(f"Missing key: {path}{key}") continue + value = data[key] + if meta["type"] == "int" and not isinstance(value, int): + errors.append(f"Invalid type for {path}{key}: expected int, got {type(value).__name__}") + elif meta["type"] == "bool" and not isinstance(value, bool): + errors.append(f"Invalid type for {path}{key}: expected bool, got {type(value).__name__}") + elif meta["type"] == "string" and not isinstance(value, str): + errors.append(f"Invalid type for {path}{key}: expected string, got {type(value).__name__}") + elif meta["type"] == "list" and not isinstance(value, list): + errors.append(f"Invalid type for {path}{key}: expected list, got {type(value).__name__}") + for item in value: + validate(item, meta["items"], path=f"{path}{key}.") + elif meta["type"] == "dict" and not isinstance(value, dict): + errors.append(f"Invalid type for {path}{key}: expected dict, got {type(value).__name__}") + validate(value, meta["items"], path=f"{path}{key}.") + validate(data) + + # hardcode warning + data['config_version'] = self.context.config_helper.config_version + data['dashboard'] = asdict(self.context.config_helper.dashboard) + + return errors - path = config['path'].split('.') - if len(path) == 0: - continue - - if config['val_type'] == "bool": - self._write_config( - namespace, config['path'], config['value']) - elif config['val_type'] == "str": - self._write_config( - namespace, config['path'], config['value']) - elif config['val_type'] == "int": - try: - self._write_config( - namespace, config['path'], int(config['value'])) - except: - raise ValueError(f"配置项 {config['name']} 的值必须是整数") - elif config['val_type'] == "float": - try: - self._write_config( - namespace, config['path'], float(config['value'])) - except: - raise ValueError(f"配置项 {config['name']} 的值必须是浮点数") - elif config['val_type'] == "list": - if config['value'] is None: - self._write_config(namespace, config['path'], []) - elif not isinstance(config['value'], list): - raise ValueError(f"配置项 {config['name']} 的值必须是列表") - self._write_config( - namespace, config['path'], config['value']) - else: - raise NotImplementedError( - f"未知或者未实现的配置项类型:{config['val_type']}") - - def _write_config(self, namespace: str, key: str, value): - if namespace == "" or namespace.startswith("internal_"): - # 机器人自带配置,存到 config.yaml - self.context.config_helper.put_by_dot_str(key, value) - else: - update_config(namespace, key, value) + def save_astrbot_config(self, post_config: dict): + '''验证并保存配置''' + errors = self.validate_config(post_config) + if errors: + raise ValueError(f"格式校验未通过: {errors}") + self.context.config_helper.flush_config(post_config) + + def save_extension_config(self, post_config: list, namespace: str): + pass + # update_config(namespace, key, value) \ No newline at end of file diff --git a/dashboard/server.py b/dashboard/server.py index e83a37161..074e08776 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -19,6 +19,7 @@ from dashboard.helper import DashBoardHelper from util.io import get_local_ip_addresses from model.plugin.manager import PluginManager from util.updator.astrbot_updator import AstrBotUpdator +from type.config import CONFIG_METADATA_2 logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -30,7 +31,7 @@ class AstrBotDashBoard(): self.plugin_manager = plugin_manager self.astrbot_updator = astrbot_updator self.dashboard_data = DashBoardData() - self.dashboard_helper = DashBoardHelper(self.context, self.dashboard_data) + self.dashboard_helper = DashBoardHelper(self.context) self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/") logging.getLogger('werkzeug').setLevel(logging.ERROR) @@ -68,8 +69,8 @@ class AstrBotDashBoard(): @self.dashboard_be.post("/api/authenticate") def authenticate(): - username = self.context.base_config.get("dashboard_username", "") - password = self.context.base_config.get("dashboard_password", "") + username = self.context.config_helper.dashboard.username + password = self.context.config_helper.dashboard.password # 获得请求体 post_data = request.json if post_data["username"] == username and post_data["password"] == password: @@ -90,7 +91,7 @@ class AstrBotDashBoard(): @self.dashboard_be.post("/api/change_password") def change_password(): - password = self.context.base_config.get("dashboard_password", "") + password = self.context.config_helper.dashboard.password # 获得请求体 post_data = request.json if post_data["password"] == password: @@ -130,40 +131,53 @@ class AstrBotDashBoard(): @self.dashboard_be.get("/api/configs") def get_configs(): - # 如果params中有namespace,则返回该namespace下的配置 - # 否则返回所有配置 + # namespace 为空时返回 AstrBot 配置 + # 否则返回指定 namespace 的插件配置 namespace = "" if "namespace" not in request.args else request.args["namespace"] - conf = self._get_configs(namespace) - return Response( - status="success", - message="", - data=conf - ).__dict__ - - @self.dashboard_be.get("/api/config_outline") - def get_config_outline(): - outline = self._generate_outline() - return Response( - status="success", - message="", - data=outline - ).__dict__ - - @self.dashboard_be.post("/api/configs") - def post_configs(): - post_configs = request.json - try: - self.on_post_configs(post_configs) + if not namespace: return Response( status="success", - message="保存成功~ 机器人将在 2 秒内重启以应用新的配置。", + message="", + data=self._get_astrbot_config() + ).__dict__ + return Response( + status="success", + message="", + data=self._get_extension_config(namespace) + ).__dict__ + + @self.dashboard_be.post("/api/astrbot-configs") + def post_astrbot_configs(): + post_configs = request.json + try: + self.save_astrbot_configs(post_configs) + return Response( + status="success", + message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。", data=None ).__dict__ except Exception as e: return Response( status="error", message=e.__str__(), - data=self.dashboard_data.configs + data=None + ).__dict__ + + @self.dashboard_be.post("/api/extension-configs") + def post_extension_configs(): + post_configs = request.json + try: + self.save_extension_configs(post_configs) + return Response( + status="success", + message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。", + data=None + ).__dict__ + except Exception as e: + return Response( + status="error", + message=e.__str__(), + data=None ).__dict__ @self.dashboard_be.get("/api/extensions") @@ -363,47 +377,41 @@ class AstrBotDashBoard(): data=None ).__dict__ - def on_post_configs(self, post_configs: dict): + def save_astrbot_configs(self, post_configs: dict): try: - if 'base_config' in post_configs: - self.dashboard_helper.save_config( - post_configs['base_config'], namespace='') # 基础配置 - self.dashboard_helper.save_config( - post_configs['config'], namespace=post_configs['namespace']) # 选定配置 - self.dashboard_helper.parse_default_config( - self.dashboard_data, self.context.config_helper.get_all()) - # 重启 - threading.Thread(target=self.astrbot_updator._reboot, - args=(2, ), daemon=True).start() + self.dashboard_helper.save_astrbot_config(post_configs) + threading.Thread(target=self.astrbot_updator._reboot, args=(3, ), daemon=True).start() except Exception as e: raise e + + def save_extension_configs(self, post_configs: dict): + try: + self.dashboard_helper.save_extension_config(post_configs) + threading.Thread(target=self.astrbot_updator._reboot, args=(3, ), daemon=True).start() + except Exception as e: + raise e + + def _get_astrbot_config(self): + config = self.context.config_helper.to_dict() + for key in self.dashboard_helper.config_key_dont_show: + if key in config: + del config[key] + return { + "metadata": CONFIG_METADATA_2, + "config": config, + } - def _get_configs(self, namespace: str): - if namespace == "": - ret = [self.dashboard_data.configs['data'][4], - self.dashboard_data.configs['data'][5],] - elif namespace == "internal_platform_qq_official": - ret = [self.dashboard_data.configs['data'][0],] - elif namespace == "internal_platform_qq_gocq": - ret = [self.dashboard_data.configs['data'][1],] - elif namespace == "internal_platform_general": # 全局平台配置 - ret = [self.dashboard_data.configs['data'][2],] - elif namespace == "internal_llm_openai_official": - ret = [self.dashboard_data.configs['data'][3],] - elif namespace == "internal_platform_qq_aiocqhttp": - ret = [self.dashboard_data.configs['data'][6],] - else: - path = f"data/config/{namespace}.json" - if not os.path.exists(path): - return [] - with open(path, "r", encoding="utf-8-sig") as f: - ret = [{ - "config_type": "group", - "name": namespace + " 插件配置", - "description": "", - "body": list(json.load(f).values()) - },] - return ret + def _get_extension_config(self, namespace: str): + path = f"data/config/{namespace}.json" + if not os.path.exists(path): + return [] + with open(path, "r", encoding="utf-8-sig") as f: + return [{ + "config_type": "group", + "name": namespace + " 插件配置", + "description": "", + "body": list(json.load(f).values()) + },] def _generate_outline(self): ''' diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 377cd9df8..a8b97de25 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -62,12 +62,11 @@ class InternalCommandHandler: return CommandResult().message("你没有权限使用该指令。") l = message_str.split(" ") if len(l) == 1: - return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.nick}") + return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.config_helper.wake_prefix}") nick = l[1].strip() if not nick: return CommandResult().message("wake: 请指定唤醒词。") - context.config_helper.put("nick_qq", nick) - context.nick = tuple(nick) + context.config_helper.wake_prefix = [nick] return CommandResult( hit=True, success=True, @@ -232,17 +231,17 @@ class InternalCommandHandler: ) def t2i_toggle(self, message: AstrMessageEvent, context: Context): - p = context.t2i_mode + p = context.config_helper.t2i if p: context.config_helper.put("qq_pic_mode", False) - context.t2i_mode = False + context.config_helper.t2i = False return CommandResult( hit=True, success=True, message_chain="已关闭文本转图片模式。", ) context.config_helper.put("qq_pic_mode", True) - context.t2i_mode = True + context.config_helper.t2i = True return CommandResult( hit=True, diff --git a/model/platform/__init__.py b/model/platform/__init__.py index 7eadf4270..56ea9ded0 100644 --- a/model/platform/__init__.py +++ b/model/platform/__init__.py @@ -52,10 +52,11 @@ class Platform(): return ret[:100] if len(ret) > 100 else ret def check_nick(self, message_str: str) -> bool: - if self.context.nick: - for nick in self.context.nick: - if nick and message_str.strip().startswith(nick): - return True + w = self.context.config_helper.wake_prefix + if not w: return False + for nick in w: + if nick and message_str.strip().startswith(nick): + return True return False async def convert_to_t2i_chain(self, message_result: list) -> list: diff --git a/model/platform/manager.py b/model/platform/manager.py index 5ca217346..028a91e8f 100644 --- a/model/platform/manager.py +++ b/model/platform/manager.py @@ -6,6 +6,12 @@ from type.types import Context from SparkleLogging.utils.core import LogManager from logging import Logger from astrbot.message.handler import MessageHandler +from util.cmd_config import ( + PlatformConfig, + AiocqhttpPlatformConfig, + NakuruPlatformConfig, + QQOfficialPlatformConfig +) logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -13,36 +19,40 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot') class PlatformManager(): def __init__(self, context: Context, message_handler: MessageHandler) -> None: self.context = context - self.config = context.base_config self.msg_handler = message_handler def load_platforms(self): tasks = [] - if 'gocqbot' in self.config and self.config['gocqbot']['enable']: - logger.info("启用 QQ(nakuru 适配器)") - tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter")) - - if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']: - logger.info("启用 QQ(aiocqhttp 适配器)") - tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter")) + platforms = self.context.config_helper.platform + logger.info(f"加载 {len(platforms)} 个机器人消息平台...") + for platform in platforms: + if not platform.enable: + continue + if platform.name == "qq_official": + assert isinstance(platform, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。" + logger.info(f"加载 QQ官方 机器人消息平台 (appid: {platform.appid})") + tasks.append(asyncio.create_task(self.qqofficial_bot(platform), name="qqofficial-adapter")) + elif platform.name == "nakuru": + assert isinstance(platform, NakuruPlatformConfig), "nakuru: 无法识别的配置类型。" + logger.info(f"加载 QQ(nakuru) 机器人消息平台 ({platform.host}, {platform.websocket_port}, {platform.port})") + tasks.append(asyncio.create_task(self.nakuru_bot(platform), name="nakuru-adapter")) + elif platform.name == "aiocqhttp": + assert isinstance(platform, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。" + logger.info("加载 QQ(aiocqhttp) 机器人消息平台") + tasks.append(asyncio.create_task(self.aiocq_bot(platform), name="aiocqhttp-adapter")) - # QQ频道 - if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None: - logger.info("启用 QQ(官方 API) 机器人消息平台") - tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter")) - return tasks - async def gocq_bot(self): + async def nakuru_bot(self, config: NakuruPlatformConfig): ''' 运行 QQ(nakuru 适配器) ''' - from model.platform.qq_nakuru import QQGOCQ + from model.platform.qq_nakuru import QQNakuru noticed = False - host = self.config.get("gocq_host", "127.0.0.1") - port = self.config.get("gocq_websocket_port", 6700) - http_port = self.config.get("gocq_http_port", 5700) + host = config.host + port = config.websocket_port + http_port = config.port logger.info( f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}") while True: @@ -56,32 +66,32 @@ class PlatformManager(): logger.info("nakuru 适配器已连接。") break try: - qq_gocq = QQGOCQ(self.context, self.msg_handler) + qq_gocq = QQNakuru(self.context, self.msg_handler, config) self.context.platforms.append(RegisteredPlatform( - platform_name="gocq", platform_instance=qq_gocq, origin="internal")) + platform_name="nakuru", platform_instance=qq_gocq, origin="internal")) await qq_gocq.run() except BaseException as e: logger.error("启动 nakuru 适配器时出现错误: " + str(e)) - def aiocq_bot(self): + def aiocq_bot(self, config): ''' 运行 QQ(aiocqhttp 适配器) ''' from model.platform.qq_aiocqhttp import AIOCQHTTP - qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler) + qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler, config) self.context.platforms.append(RegisteredPlatform( platform_name="aiocqhttp", platform_instance=qq_aiocqhttp, origin="internal")) return qq_aiocqhttp.run_aiocqhttp() - def qqchan_bot(self): + def qqofficial_bot(self, config): ''' 运行 QQ 官方机器人适配器 ''' try: from model.platform.qq_official import QQOfficial - qqchannel_bot = QQOfficial(self.context, self.msg_handler) + qqchannel_bot = QQOfficial(self.context, self.msg_handler, config) self.context.platforms.append(RegisteredPlatform( - platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal")) + platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal")) return qqchannel_bot.run() except BaseException as e: logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e)) diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index 0788a5b2e..3f03c6cf4 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -13,18 +13,25 @@ from nakuru.entities.components import * from SparkleLogging.utils.core import LogManager from logging import Logger from astrbot.message.handler import MessageHandler +from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig logger: Logger = LogManager.GetLogger(log_name='astrbot') class AIOCQHTTP(Platform): - def __init__(self, context: Context, message_handler: MessageHandler) -> None: + def __init__(self, context: Context, + message_handler: MessageHandler, + platform_config: PlatformConfig) -> None: + assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。" + self.message_handler = message_handler self.waiting = {} self.context = context - self.unique_session = self.context.unique_session - self.announcement = self.context.base_config.get("announcement", "欢迎新人!") - self.host = self.context.base_config['aiocqhttp']['ws_reverse_host'] - self.port = self.context.base_config['aiocqhttp']['ws_reverse_port'] + self.config = platform_config + self.unique_session = context.config_helper.platform_settings.unique_session + self.announcement = context.config_helper.platform_settings.welcome_message_when_join + self.host = platform_config.ws_reverse_host + self.port = platform_config.ws_reverse_port + self.admins = context.config_helper.admins_id def convert_message(self, event: Event) -> AstrBotMessage: @@ -123,12 +130,11 @@ class AIOCQHTTP(Platform): # 解析 role sender_id = str(message.sender.user_id) - if sender_id == self.context.base_config.get('admin_qq', '') or \ - sender_id in self.context.base_config.get('other_admins', []): + if sender_id in self.admins: role = 'admin' else: role = 'member' - + # construct astrbot message event ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role) @@ -160,7 +166,7 @@ class AIOCQHTTP(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): + if self.context.config_helper.t2i and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: diff --git a/model/platform/qq_nakuru.py b/model/platform/qq_nakuru.py index d4052094d..0e586f1f2 100644 --- a/model/platform/qq_nakuru.py +++ b/model/platform/qq_nakuru.py @@ -18,6 +18,7 @@ from type.command import * from SparkleLogging.utils.core import LogManager from logging import Logger from astrbot.message.handler import MessageHandler +from util.cmd_config import PlatformConfig, NakuruPlatformConfig logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -28,46 +29,52 @@ class FakeSource: self.group_id = group_id -class QQGOCQ(Platform): - def __init__(self, context: Context, message_handler: MessageHandler) -> None: +class QQNakuru(Platform): + def __init__(self, context: Context, + message_handler: MessageHandler, + platform_config: PlatformConfig) -> None: + assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。" + self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.message_handler = message_handler self.waiting = {} self.context = context - self.unique_session = self.context.unique_session - self.announcement = self.context.base_config.get("announcement", "欢迎新人!") - + self.unique_session = context.config_helper.platform_settings.unique_session + self.announcement = context.config_helper.platform_settings.welcome_message_when_join + self.config = platform_config + self.admins = context.config_helper.admins_id + self.client = CQHTTP( - host=self.context.base_config.get("gocq_host", "127.0.0.1"), - port=self.context.base_config.get("gocq_websocket_port", 6700), - http_port=self.context.base_config.get("gocq_http_port", 5700), + host=self.config.host, + port=self.config.websocket_port, + http_port=self.config.port ) gocq_app = self.client @gocq_app.receiver("GroupMessage") async def _(app: CQHTTP, source: GroupMessage): - if self.context.base_config.get("gocq_react_group", True): + if self.config.enable_group: abm = self.convert_message(source) await self.handle_msg(abm) @gocq_app.receiver("FriendMessage") async def _(app: CQHTTP, source: FriendMessage): - if self.context.base_config.get("gocq_react_friend", True): + if self.config.enable_direct_message: abm = self.convert_message(source) await self.handle_msg(abm) @gocq_app.receiver("GroupMemberIncrease") async def _(app: CQHTTP, source: GroupMemberIncrease): - if self.context.base_config.get("gocq_react_group_increase", True): + if self.config.enable_group_increase: await app.sendGroupMessage(source.group_id, [ Plain(text=self.announcement) ]) @gocq_app.receiver("GuildMessage") async def _(app: CQHTTP, source: GuildMessage): - if self.cc.get("gocq_react_guild", True): + if self.config.enable_guild: abm = self.convert_message(source) await self.handle_msg(abm) @@ -112,8 +119,7 @@ class QQGOCQ(Platform): # 解析 role sender_id = str(message.raw_message.user_id) - if sender_id == self.context.base_config.get('admin_qq', '') or \ - sender_id in self.context.base_config.get('other_admins', []): + if sender_id in self.admins: role = 'admin' else: role = 'member' @@ -152,7 +158,7 @@ class QQGOCQ(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list): + if self.context.config_helper.t2i and isinstance(res, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: @@ -186,7 +192,7 @@ class QQGOCQ(Platform): plain_text_len += len(i.text) elif isinstance(i, Image): image_num += 1 - if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200): + if plain_text_len > self.context.config_helper.platform_settings.forward_threshold or image_num > 1: # 删除At for i in message_chain: if isinstance(i, At): diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index 5ca3e301b..297ceb4b8 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -19,6 +19,7 @@ from nakuru.entities.components import * from SparkleLogging.utils.core import LogManager from logging import Logger from astrbot.message.handler import MessageHandler +from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -52,32 +53,36 @@ class botClient(Client): class QQOfficial(Platform): - def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None: - super().__init__() + def __init__(self, context: Context, + message_handler: MessageHandler, + platform_config: PlatformConfig, + test_mode = False) -> None: + assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。" self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.message_handler = message_handler self.waiting: dict = {} self.context = context + self.config = platform_config + self.admins = context.config_helper.admins_id - self.appid = context.base_config['qqbot']['appid'] - self.token = context.base_config['qqbot']['token'] - self.secret = context.base_config['qqbot_secret'] - self.unique_session = context.unique_session - qq_group = context.base_config['qqofficial_enable_group_message'] - + self.appid = platform_config.appid + self.secret = platform_config.secret + self.unique_session = context.config_helper.platform_settings.unique_session + qq_group = platform_config.enable_group_c2c + guild_dm = platform_config.enable_guild_direct_message if qq_group: self.intents = botpy.Intents( public_messages=True, public_guild_messages=True, - direct_message=context.base_config['direct_message_mode'] + direct_message=guild_dm ) else: self.intents = botpy.Intents( public_guild_messages=True, - direct_message=context.base_config['direct_message_mode'] + direct_message=guild_dm ) self.client = botClient( intents=self.intents, @@ -168,24 +173,11 @@ class QQOfficial(Platform): return abm def run(self): - try: - return self.client.start( - appid=self.appid, - secret=self.secret - ) - except BaseException as e: - # 早期的 qq-botpy 版本使用 token 登录。 - logger.error(traceback.format_exc()) - self.client = botClient( - intents=self.intents, - bot_log=False - ) - self.client.set_platform(self) - return self.client.start( - appid=self.appid, - token=self.token - ) - + return self.client.start( + appid=self.appid, + secret=self.secret + ) + async def handle_msg(self, message: AstrBotMessage): assert isinstance(message.raw_message, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage)) @@ -209,8 +201,7 @@ class QQOfficial(Platform): # 解析出 role sender_id = message.sender.user_id - if sender_id == self.context.base_config.get('admin_qqchan', None) or \ - sender_id in self.context.base_config.get('other_admins', None): + if sender_id in self.admins: role = 'admin' else: role = 'member' @@ -249,7 +240,7 @@ class QQOfficial(Platform): msg_ref = None rendered_images = [] - if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list): + if self.context.config_helper.t2i and isinstance(result_message, list): rendered_images = await self.convert_to_t2i_chain(result_message) if isinstance(result_message, list): diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index b5eefcd23..69a8cee00 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -1,5 +1,3 @@ -import os -import sys import json import time import tiktoken @@ -15,12 +13,12 @@ from openai._exceptions import * from astrbot.persist.helper import dbConn from model.provider.provider import Provider from util import general_utils as gu -from util.cmd_config import CmdConfig +from util.cmd_config import LLMConfig from SparkleLogging.utils.core import LogManager from logging import Logger from typing import List, Dict -from type.types import Context +from dataclasses import asdict logger: Logger = LogManager.GetLogger(log_name='astrbot') @@ -48,22 +46,16 @@ MODELS = { } class ProviderOpenAIOfficial(Provider): - def __init__(self, context: Context) -> None: + def __init__(self, llm_config: LLMConfig) -> None: super().__init__() - os.makedirs("data/openai", exist_ok=True) - - self.context = context - self.key_data_path = "data/openai/keys.json" self.api_keys = [] self.chosen_api_key = None self.base_url = None + self.llm_config = llm_config self.keys_data = {} # 记录超额 - - cfg = context.base_config['openai'] - - if cfg['key']: self.api_keys = cfg['key'] - if cfg['api_base']: self.base_url = cfg['api_base'] + if llm_config.key: self.api_keys = llm_config.key + if llm_config.api_base: self.base_url = llm_config.api_base if not self.api_keys: logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。") else: @@ -76,18 +68,20 @@ class ProviderOpenAIOfficial(Provider): api_key=self.chosen_api_key, base_url=self.base_url ) - self.model_configs: Dict = cfg['chatGPTConfigs'] - super().set_curr_model(self.model_configs['model']) - self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None) + super().set_curr_model(llm_config.model_config.model) + self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config) self.session_memory: Dict[str, List] = {} # 会话记忆 self.session_memory_lock = threading.Lock() - self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小 + self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小 logger.info("正在载入分词器 cl100k_base...") self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器 logger.info("分词器载入完成。") - self.DEFAULT_PERSONALITY = context.default_personality + self.DEFAULT_PERSONALITY = { + "prompt": self.llm_config.default_personality, + "name": "default" + } self.curr_personality = self.DEFAULT_PERSONALITY self.session_personality = {} # 记录了某个session是否已设置人格。 # 从 SQLite DB 读取历史记录 @@ -134,7 +128,7 @@ class ProviderOpenAIOfficial(Provider): self.session_personality = {} # 重置 encoded_prompt = self.tokenizer.encode(default_personality['prompt']) tokens_num = len(encoded_prompt) - model = self.model_configs['model'] + model = self.get_curr_model() if model in MODELS and tokens_num > MODELS[model] - 500: default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500]) @@ -173,7 +167,7 @@ class ProviderOpenAIOfficial(Provider): for record in self.session_memory[session_id]: if "user" in record and record['user']: if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list): - logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。") + logger.warn(f"由于当前模型 {self.get_curr_model()} 不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。") continue context.append(record['user']) if "AI" in record and record['AI']: @@ -185,7 +179,7 @@ class ProviderOpenAIOfficial(Provider): ''' 是否是 LVM ''' - return self.model_configs['model'].startswith("gpt-4") + return self.get_curr_model().startswith("gpt-4") async def get_models(self): try: @@ -238,7 +232,7 @@ class ProviderOpenAIOfficial(Provider): self.session_memory[session_id].append(message) # 根据 模型的上下文窗口 淘汰掉多余的记录 - curr_model = self.model_configs['model'] + curr_model = self.get_curr_model() if curr_model in MODELS: maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion # if message['usage_tokens'] > maxium_tokens_num: @@ -314,7 +308,7 @@ class ProviderOpenAIOfficial(Provider): # 1. 可以保证之后 pop 的时候不会出现问题 # 2. 可以保证不会超过最大 token 数 _encoded_prompt = self.tokenizer.encode(prompt) - curr_model = self.model_configs['model'] + curr_model = self.get_curr_model() if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300: _encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300] prompt = self.tokenizer.decode(_encoded_prompt) @@ -325,7 +319,7 @@ class ProviderOpenAIOfficial(Provider): # 获取上下文,openai 格式 contexts = await self.retrieve_context(session_id) - conf = self.model_configs + conf = asdict(self.llm_config.model_config) if extra_conf: conf.update(extra_conf) # start request @@ -358,7 +352,7 @@ class ProviderOpenAIOfficial(Provider): except BadRequestError as e: logger.warn(f"OpenAI 请求异常:{e}。") if "image_url is only supported by certain models." in str(e): - raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。") + raise Exception(f"当前模型 { self.get_curr_model() } 不支持图片输入,请更换模型。") retry += 1 except RateLimitError as e: if "You exceeded your current quota" in str(e): @@ -493,12 +487,11 @@ class ProviderOpenAIOfficial(Provider): return contexts_str, len(self.session_memory[session_id]) def set_model(self, model: str): - self.model_configs['model'] = model - self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model) + # TODO: 更新配置文件 super().set_curr_model(model) def get_configs(self): - return self.model_configs + return asdict(self.llm_config) def get_keys_data(self): return self.keys_data diff --git a/type/config.py b/type/config.py index e13bc9262..3a44a0b2b 100644 --- a/type/config.py +++ b/type/config.py @@ -72,4 +72,281 @@ DEFAULT_CONFIG = { "ws_reverse_host": "", "ws_reverse_port": 0, } +} + +# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D +DEFAULT_CONFIG_VERSION_2 = { + "config_version": 2, + "platform": [ + { + "name": "qq_official", + "enable": False, + "appid": "", + "secret": "", + "enable_group_c2c": True, + "enable_guild_direct_message": True, + }, + { + "name": "nakuru", + "enable": False, + "host": "172.0.0.1", + "port": 5700, + "websocket_port": 6700, + "enable_group": True, + "enable_guild": True, + "enable_direct_message": True, + "enable_group_increase": True, + }, + { + "name": "aiocqhttp", + "enable": False, + "ws_reverse_host": "", + "ws_reverse_port": 6199, + } + ], + "platform_settings": { + "unique_session": False, + "welcome_message_when_join": "", + "rate_limit": { + "time": 60, + "count": 30, + }, + "reply_prefix": "", + "forward_threshold": 200, # 转发消息的阈值 + }, + "llm": [ + { + "name": "openai", + "enable": True, + "key": [], + "api_base": "", + "prompt_prefix": "", + "default_personality": "", + "model_config": { + "model": "gpt-4o", + "max_tokens": 6000, + "temperature": 0.9, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + }, + "image_generation_model_config": { + "enable": True, + "model": "dall-e-3", + "size": "1024x1024", + "style": "vivid", + "quality": "standard", + } + }, + ], + "llm_settings": { + "wake_prefix": "", + "web_search": False, + }, + "content_safety": { + "baidu_aip": { + "enable": False, + "app_id": "", + "api_key": "", + "secret_key": "", + }, + "internal_keywords": { + "enable": True, + "extra_keywords": [], + } + }, + "wake_prefix": [], + "t2i": True, + "dump_history_interval": 10, + "admins_id": [], + "https_proxy": "", + "http_proxy": "", + "dashboard": { + "enable": True, + "username": "", + "password": "", + }, +} + +# 这个是用于迁移旧版本配置文件的映射表 +MAPPINGS_1_2 = [ + [["qqbot", "enable"], ["platform", 0, "enable"]], + [["qqbot", "appid"], ["platform", 0, "appid"]], + [["qqbot", "token"], ["platform", 0, "secret"]], + [["qqofficial_enable_group_message"], ["platform", 0, "enable_group_c2c"]], + [["direct_message_mode"], ["platform", 0, "enable_guild_direct_message"]], + [["gocqbot", "enable"], ["platform", 1, "enable"]], + [["gocq_host"], ["platform", 1, "host"]], + [["gocq_http_port"], ["platform", 1, "port"]], + [["gocq_websocket_port"], ["platform", 1, "websocket_port"]], + [["gocq_react_group"], ["platform", 1, "enable_group"]], + [["gocq_react_guild"], ["platform", 1, "enable_guild"]], + [["gocq_react_friend"], ["platform", 1, "enable_direct_message"]], + [["gocq_react_group_increase"], ["platform", 1, "enable_group_increase"]], + [["aiocqhttp", "enable"], ["platform", 2, "enable"]], + [["aiocqhttp", "ws_reverse_host"], ["platform", 2, "ws_reverse_host"]], + [["aiocqhttp", "ws_reverse_port"], ["platform", 2, "ws_reverse_port"]], + [["uniqueSessionMode"], ["platform_settings", "unique_session"]], + [["qq_welcome"], ["platform_settings", "welcome_message_when_join"]], + [["limit", "time"], ["platform_settings", "rate_limit", "time"]], + [["limit", "count"], ["platform_settings", "rate_limit", "count"]], + [["reply_prefix"], ["platform_settings", "reply_prefix"]], + [["qq_forward_threshold"], ["platform_settings", "forward_threshold"]], + + [["openai", "key"], ["llm", 0, "key"]], + [["openai", "api_base"], ["llm", 0, "api_base"]], + [["openai", "chatGPTConfigs", "model"], ["llm", 0, "model_config", "model"]], + [["openai", "chatGPTConfigs", "max_tokens"], ["llm", 0, "model_config", "max_tokens"]], + [["openai", "chatGPTConfigs", "temperature"], ["llm", 0, "model_config", "temperature"]], + [["openai", "chatGPTConfigs", "top_p"], ["llm", 0, "model_config", "top_p"]], + [["openai", "chatGPTConfigs", "frequency_penalty"], ["llm", 0, "model_config", "frequency_penalty"]], + [["openai", "chatGPTConfigs", "presence_penalty"], ["llm", 0, "model_config", "presence_penalty"]], + + [["default_personality_str"], ["llm", 0, "default_personality"]], + [["llm_env_prompt"], ["llm", 0, "prompt_prefix"]], + [["openai_image_generate", "model"], ["llm", 0, "image_generation_model_config", "model"]], + [["openai_image_generate", "size"], ["llm", 0, "image_generation_model_config", "size"]], + [["openai_image_generate", "style"], ["llm", 0, "image_generation_model_config", "style"]], + [["openai_image_generate", "quality"], ["llm", 0, "image_generation_model_config", "quality"]], + + [["llm_wake_prefix"], ["llm_settings", "wake_prefix"]], + + [["baidu_aip", "enable"], ["content_safety", "baidu_aip", "enable"]], + [["baidu_aip", "app_id"], ["content_safety", "baidu_aip", "app_id"]], + [["baidu_aip", "api_key"], ["content_safety", "baidu_aip", "api_key"]], + [["baidu_aip", "secret_key"], ["content_safety", "baidu_aip", "secret_key"]], + + [["qq_pic_mode"], ["t2i"]], + [["dump_history_interval"], ["dump_history_interval"]], + [["other_admins"], ["admins_id"]], + [["http_proxy"], ["http_proxy"]], + [["https_proxy"], ["https_proxy"]], + [["dashboard_username"], ["dashboard", "username"]], + [["dashboard_password"], ["dashboard", "password"]], + [["nick_qq"], ["wake_prefix"]], +] + +CONFIG_METADATA_2 = { + "config_version": {"description": "配置版本", "type": "int"}, + "platform": { + "description": "平台配置", + "type": "list", + "items": { + "name": {"description": "平台名称", "type": "string"}, + "enable": {"description": "是否启用", "type": "bool"}, + "appid": {"description": "应用ID", "type": "string"}, + "secret": {"description": "应用密钥", "type": "string"}, + "enable_group_c2c": {"description": "启用群C2C", "type": "bool"}, + "enable_guild_direct_message": {"description": "启用公会直接消息", "type": "bool"}, + "host": {"description": "主机地址", "type": "string"}, + "port": {"description": "端口", "type": "int"}, + "websocket_port": {"description": "WebSocket端口", "type": "int"}, + "ws_reverse_host": {"description": "WebSocket反向主机", "type": "string"}, + "ws_reverse_port": {"description": "WebSocket反向端口", "type": "int"}, + "enable_group": {"description": "启用群组", "type": "bool"}, + "enable_guild": {"description": "启用公会", "type": "bool"}, + "enable_direct_message": {"description": "启用直接消息", "type": "bool"}, + "enable_group_increase": {"description": "启用群组增加", "type": "bool"}, + } + }, + "platform_settings": { + "description": "平台设置", + "type": "object", + "items": { + "unique_session": {"description": "唯一会话", "type": "bool"}, + "welcome_message_when_join": {"description": "加入时欢迎信息", "type": "string"}, + "rate_limit": { + "description": "速率限制", + "type": "object", + "items": { + "time": {"description": "时间", "type": "int"}, + "count": {"description": "计数", "type": "int"}, + } + }, + "reply_prefix": {"description": "回复前缀", "type": "string"}, + "forward_threshold": {"description": "转发消息的阈值", "type": "int"}, + } + }, + "llm": { + "description": "大语言模型配置", + "type": "list", + "items": { + "name": {"description": "模型名称", "type": "string"}, + "enable": {"description": "是否启用", "type": "bool"}, + "key": {"description": "密钥", "type": "list", "items": {"type": "string"}}, + "api_base": {"description": "API基础URL", "type": "string"}, + "prompt_prefix": {"description": "提示前缀", "type": "string"}, + "default_personality": {"description": "默认个性", "type": "string"}, + "model_config": { + "description": "模型配置", + "type": "object", + "items": { + "model": {"description": "模型名称", "type": "string"}, + "max_tokens": {"description": "最大令牌数", "type": "int"}, + "temperature": {"description": "温度", "type": "float"}, + "top_p": {"description": "Top P值", "type": "float"}, + "frequency_penalty": {"description": "频率惩罚", "type": "float"}, + "presence_penalty": {"description": "存在惩罚", "type": "float"}, + } + }, + "image_generation_model_config": { + "description": "图像生成模型配置", + "type": "object", + "items": { + "enable": {"description": "是否启用", "type": "bool"}, + "model": {"description": "模型名称", "type": "string"}, + "size": {"description": "图像尺寸", "type": "string"}, + "style": {"description": "图像风格", "type": "string"}, + "quality": {"description": "图像质量", "type": "string"}, + } + }, + } + }, + "llm_settings": { + "description": "大语言模型设置", + "type": "object", + "items": { + "wake_prefix": {"description": "唤醒前缀", "type": "string"}, + "web_search": {"description": "启用网络搜索", "type": "bool"}, + } + }, + "content_safety": { + "description": "内容安全配置", + "type": "object", + "items": { + "baidu_aip": { + "description": "百度AI平台配置", + "type": "object", + "items": { + "enable": {"description": "是否启用", "type": "bool"}, + "app_id": {"description": "应用ID", "type": "string"}, + "api_key": {"description": "API密钥", "type": "string"}, + "secret_key": {"description": "秘密密钥", "type": "string"}, + } + }, + "internal_keywords": { + "description": "内部关键词过滤", + "type": "object", + "items": { + "enable": {"description": "是否启用", "type": "bool"}, + "extra_keywords": {"description": "额外关键词", "type": "list", "items": {"type": "string"}}, + } + } + } + }, + "wake_prefix": {"description": "唤醒前缀列表", "type": "list", "items": {"type": "string"}}, + "t2i": {"description": "文本转图像功能", "type": "bool"}, + "dump_history_interval": {"description": "历史记录转储间隔", "type": "int"}, + "admins_id": {"description": "管理员ID列表", "type": "list", "items": {"type": "int"}}, + "https_proxy": {"description": "HTTPS代理", "type": "string"}, + "http_proxy": {"description": "HTTP代理", "type": "string"}, + "dashboard": { + "description": "仪表盘配置", + "type": "object", + "items": { + "enable": {"description": "是否启用", "type": "bool"}, + "username": {"description": "用户名", "type": "string"}, + "password": {"description": "密码", "type": "string"}, + } + }, } \ No newline at end of file diff --git a/type/types.py b/type/types.py index 195e2b1a6..9f9035ca1 100644 --- a/type/types.py +++ b/type/types.py @@ -3,7 +3,7 @@ from asyncio import Task from type.register import * from typing import List, Awaitable from logging import Logger -from util.cmd_config import CmdConfig +from util.cmd_config import AstrBotConfig from util.t2i.renderer import TextToImageRenderer from util.updator.astrbot_updator import AstrBotUpdator from util.image_uploader import ImageUploader @@ -20,17 +20,17 @@ class Context: def __init__(self): self.logger: Logger = None self.base_config: dict = None # 配置(期望启动机器人后是不变的) - self.config_helper: CmdConfig = None + self.config_helper: AstrBotConfig = None self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件 self.platforms: List[RegisteredPlatform] = [] self.llms: List[RegisteredLLM] = [] self.default_personality: dict = None - self.unique_session = False # 独立会话 - self.version: str = None # 机器人版本 - self.nick: tuple = None # gocq 的唤醒词 - self.t2i_mode = False - self.web_search = False # 是否开启了网页搜索 + # self.unique_session = False # 独立会话 + # self.version: str = None # 机器人版本 + # self.nick: tuple = None # gocq 的唤醒词 + # self.t2i_mode = False + # self.web_search = False # 是否开启了网页搜索 self.metrics_uploader = None self.updator: AstrBotUpdator = None @@ -42,7 +42,7 @@ class Context: self.ext_tasks: List[Task] = [] # useless - self.reply_prefix = "" + # self.reply_prefix = "" def register_commands(self, plugin_name: str, diff --git a/util/cmd_config.py b/util/cmd_config.py index 80cac42a1..bb4e56a28 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -1,33 +1,236 @@ import os import json -from type.config import DEFAULT_CONFIG +import logging +from type.config import DEFAULT_CONFIG, DEFAULT_CONFIG_VERSION_2, MAPPINGS_1_2 +from dataclasses import dataclass, field, asdict +from typing import List, Dict, Optional -cpath = "data/cmd_config.json" +ASTRBOT_CONFIG_PATH = "data/cmd_config.json" +logger = logging.getLogger("astrbot") -def check_exist(): - if not os.path.exists(cpath): - with open(cpath, "w", encoding="utf-8-sig") as f: - json.dump({}, f, ensure_ascii=False) - f.flush() +@dataclass +class RateLimit: + time: int = 60 + count: int = 30 + +@dataclass +class PlatformSettings: + unique_session: bool = False + welcome_message_when_join: str = "" + rate_limit: RateLimit = field(default_factory=RateLimit) + reply_prefix: str = "" + forward_threshold: int = 200 + + def __post_init__(self): + self.rate_limit = RateLimit(**self.rate_limit) + +@dataclass +class PlatformConfig(): + name: str = "" + enable: bool = False + +@dataclass +class QQOfficialPlatformConfig(PlatformConfig): + appid: str = "" + secret: str = "" + enable_group_c2c: bool = True + enable_guild_direct_message: bool = True + +@dataclass +class NakuruPlatformConfig(PlatformConfig): + host: str = "172.0.0.1", + port: int = 5700, + websocket_port: int = 6700, + enable_group: bool = True, + enable_guild: bool = True, + enable_direct_message: bool = True, + enable_group_increase: bool = True + +@dataclass +class AiocqhttpPlatformConfig(PlatformConfig): + ws_reverse_host: str = "" + ws_reverse_port: int = 6199 + +@dataclass +class ModelConfig: + model: str = "gpt-4o" + max_tokens: int = 6000 + temperature: float = 0.9 + top_p: float = 1 + frequency_penalty: float = 0 + presence_penalty: float = 0 + +@dataclass +class ImageGenerationModelConfig: + enable: bool = True + model: str = "dall-e-3" + size: str = "1024x1024" + style: str = "vivid" + quality: str = "standard" + +@dataclass +class LLMConfig: + name: str = "openai" + enable: bool = True + key: List[str] = field(default_factory=list) + api_base: str = "" + prompt_prefix: str = "" + default_personality: str = "" + model_config: ModelConfig = field(default_factory=ModelConfig) + image_generation_model_config: ImageGenerationModelConfig = field(default_factory=ImageGenerationModelConfig) + + def __post_init__(self): + self.model_config = ModelConfig(**self.model_config) + self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) + +@dataclass +class LLMSettings: + wake_prefix: str = "" + web_search: bool = False + +@dataclass +class BaiduAIPConfig: + enable: bool = False + app_id: str = "" + api_key: str = "" + secret_key: str = "" + +@dataclass +class InternalKeywordsConfig: + enable: bool = True + extra_keywords: List[str] = field(default_factory=list) + +@dataclass +class ContentSafetyConfig: + baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig) + internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig) + + def __post_init__(self): + self.baidu_aip = BaiduAIPConfig(**self.baidu_aip) + self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords) + +@dataclass +class DashboardConfig: + enable: bool = True + username: str = "" + password: str = "" + +@dataclass +class AstrBotConfig(): + config_version: int = 2 + platform_settings: PlatformSettings = field(default_factory=PlatformSettings) + llm: List[LLMConfig] = field(default_factory=list) + llm_settings: LLMSettings = field(default_factory=LLMSettings) + content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig) + t2i: bool = True + dump_history_interval: int = 10 + admins_id: List[str] = field(default_factory=list) + https_proxy: str = "" + http_proxy: str = "" + dashboard: DashboardConfig = field(default_factory=DashboardConfig) + platform: List[PlatformConfig] = field(default_factory=list) + wake_prefix: List[str] = field(default_factory=list) -class CmdConfig(): def __init__(self) -> None: - self.cached_config: dict = {} self.init_configs() + + # compability + if isinstance(self.wake_prefix, str): + self.wake_prefix = [self.wake_prefix] + + def load_from_dict(self, data: Dict): + '''从字典中加载配置到对象。 + + @note: 适用于 version 2 配置文件。 + ''' + self.config_version=data.get("version", 2) + self.platform=[] + for p in data.get("platform", []): + if 'name' not in p: + logger.warning("A platform config missing name, skipping.") + continue + if p["name"] == "qq_official": + self.platform.append(QQOfficialPlatformConfig(**p)) + elif p["name"] == "nakuru": + self.platform.append(NakuruPlatformConfig(**p)) + elif p["name"] == "aiocqhttp": + self.platform.append(AiocqhttpPlatformConfig(**p)) + else: + self.platform.append(PlatformConfig(**p)) + self.platform_settings=PlatformSettings(**data.get("platform_settings", {})) + self.llm=[LLMConfig(**l) for l in data.get("llm", [])] + self.llm_settings=LLMSettings(**data.get("llm_settings", {})) + self.content_safety=ContentSafetyConfig(**data.get("content_safety", {})) + self.t2i=data.get("t2i", True) + self.dump_history_interval=data.get("dump_history_interval", 10) + self.admins_id=data.get("admins_id", []) + self.https_proxy=data.get("https_proxy", "") + self.http_proxy=data.get("http_proxy", "") + self.dashboard=DashboardConfig(**data.get("dashboard", {})) + self.wake_prefix=data.get("wake_prefix", []) + + def migrate_config_1_2(self, old: dict) -> dict: + '''将配置文件从版本 1 迁移至版本 2''' + logger.info("正在更新配置文件到 version 2...") + new_config = DEFAULT_CONFIG_VERSION_2 + mappings = MAPPINGS_1_2 + + def set_nested_value(d, keys, value): + cursor = d + for key in keys[:-1]: + cursor = cursor[key] + cursor[keys[-1]] = value + + for old_path, new_path in mappings: + value = old + try: + for key in old_path: + value = value[key] # soooooo convenient!! + set_nested_value(new_config, new_path, value) + except KeyError: + # 如果旧配置中没有这个键,跳过,即使用新配置的默认值 + continue + + logger.info("配置文件更新完成。") + return new_config + + def flush_config(self, config: dict = None): + '''将配置写入文件, 如果没有传入配置,则写入默认配置''' + with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: + json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False) + f.flush() def init_configs(self): - ''' - 初始化必需的配置项 - ''' - self.init_config_items(DEFAULT_CONFIG) + '''初始化必需的配置项''' + config = None + + if not self.check_exist(): + self.flush_config() + config = DEFAULT_CONFIG_VERSION_2 + else: + config = self.get_all() + # check if the config is outdated + if 'config_version' not in config: # version 1 + config = self.migrate_config_1_2(config) + self.flush_config(config) + + _tag = False + for key, val in DEFAULT_CONFIG_VERSION_2.items(): + if key not in config: + config[key] = val + _tag = True + if _tag: + with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + f.flush() - @staticmethod - def get(key, default=None): + self.load_from_dict(config) + + def get(self, key, default=None): ''' 从文件系统中直接获取配置 ''' - check_exist() - with open(cpath, "r", encoding="utf-8-sig") as f: + with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: d = json.load(f) if key in d: return d[key] @@ -38,8 +241,7 @@ class CmdConfig(): ''' 从文件系统中获取所有配置 ''' - check_exist() - with open(cpath, "r", encoding="utf-8-sig") as f: + with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: conf_str = f.read() if conf_str.startswith(u'/ufeff'): # remove BOM conf_str = conf_str.encode('utf8')[3:].decode('utf8') @@ -47,21 +249,19 @@ class CmdConfig(): return conf def put(self, key, value): - with open(cpath, "r", encoding="utf-8-sig") as f: + with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: d = json.load(f) d[key] = value - with open(cpath, "w", encoding="utf-8-sig") as f: + with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: json.dump(d, f, indent=2, ensure_ascii=False) f.flush() - - self.cached_config[key] = value - - @staticmethod - def put_by_dot_str(key: str, value): - ''' - 根据点分割的字符串,将值写入配置文件 - ''' - with open(cpath, "r", encoding="utf-8-sig") as f: + + def to_dict(self) -> Dict: + return asdict(self) + + def put_by_dot_str(self, key: str, value): + '''根据点分割的字符串,将值写入配置文件''' + with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f: d = json.load(f) _d = d _ks = key.split(".") @@ -70,23 +270,20 @@ class CmdConfig(): _d[_ks[i]] = value else: _d = _d[_ks[i]] - with open(cpath, "w", encoding="utf-8-sig") as f: + with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f: json.dump(d, f, indent=2, ensure_ascii=False) f.flush() - - def init_config_items(self, d: dict): - conf = self.get_all() + + def update_by_path(self, path: List): + '''根据路径更新配置文件。 - if not self.cached_config: - self.cached_config = conf + 这个方法首先会更新缓存在内存中的配置,然后再写入文件。 + ''' + + for key in path: + if key not in self: + raise KeyError(f"Key {key} not found in config.") - _tag = False - for key, val in d.items(): - if key not in conf: - conf[key] = val - _tag = True - if _tag: - with open(cpath, "w", encoding="utf-8-sig") as f: - json.dump(conf, f, indent=2, ensure_ascii=False) - f.flush() + def check_exist(self) -> bool: + return os.path.exists(ASTRBOT_CONFIG_PATH) \ No newline at end of file diff --git a/util/config_utils.py b/util/config_utils.py index 8fa9a1c46..356bd6b90 100644 --- a/util/config_utils.py +++ b/util/config_utils.py @@ -1,16 +1,15 @@ -import json, os -from util.cmd_config import CmdConfig +import json, os, shutil +import logging + +logger = logging.getLogger("astrbot") def try_migrate_config(): ''' 将 cmd_config.json 迁移至 data/cmd_config.json (如果存在的话) ''' - if os.path.exists("cmd_config.json"): - with open("cmd_config.json", "r", encoding="utf-8-sig") as f: - data = json.load(f) - with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f: - json.dump(data, f, indent=2, ensure_ascii=False) + if os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"): try: - os.remove("cmd_config.json") - except Exception as e: - pass \ No newline at end of file + shutil.move("cmd_config.json", "data/cmd_config.json") + except: + logger.error("迁移 cmd_config.json 失败。AstrBot 将不会读取配置文件,你可以手动将 cmd_config.json 迁移至 data/cmd_config.json。") + \ No newline at end of file diff --git a/util/metrics.py b/util/metrics.py index 172dd7ad4..e905476f8 100644 --- a/util/metrics.py +++ b/util/metrics.py @@ -5,6 +5,7 @@ import sys from type.types import Context from collections import defaultdict +from type.config import VERSION class MetricUploader(): def __init__(self, context: Context) -> None: @@ -49,7 +50,7 @@ class MetricUploader(): try: res = { "stat_version": "moon", - "version": context.version, # 版本号 + "version": VERSION, # 版本号 "platform_stats": self.platform_stats, # 过去 30 分钟各消息平台交互消息数 "llm_stats": self.llm_stats, "plugin_stats": self.plugin_stats, diff --git a/util/plugin_dev/api/v1/platform.py b/util/plugin_dev/api/v1/platform.py index 24ad51ba0..55574e0cf 100644 --- a/util/plugin_dev/api/v1/platform.py +++ b/util/plugin_dev/api/v1/platform.py @@ -7,6 +7,6 @@ Platform类是消息平台的抽象类,定义了消息平台的基本接口。 from model.platform import Platform -from model.platform.qq_nakuru import QQGOCQ +from model.platform.qq_nakuru import QQNakuru from model.platform.qq_official import QQOfficial from model.platform.qq_aiocqhttp import AIOCQHTTP \ No newline at end of file From a7c87642b4ab16b9426c4776a4849be410b0b9a3 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 6 Aug 2024 23:21:18 -0400 Subject: [PATCH 2/5] refactor: Update configuration format and handling --- dashboard/helper.py | 20 ++++++++-- dashboard/server.py | 61 +------------------------------ model/provider/openai_official.py | 6 +-- type/config.py | 4 ++ util/cmd_config.py | 10 +++-- 5 files changed, 32 insertions(+), 69 deletions(-) diff --git a/dashboard/helper.py b/dashboard/helper.py index c9c17f7de..9cef2a4ff 100644 --- a/dashboard/helper.py +++ b/dashboard/helper.py @@ -54,6 +54,20 @@ class DashBoardHelper(): raise ValueError(f"格式校验未通过: {errors}") self.context.config_helper.flush_config(post_config) - def save_extension_config(self, post_config: list, namespace: str): - pass - # update_config(namespace, key, value) \ No newline at end of file + def save_extension_config(self, post_config: dict): + if 'namespace' not in post_config: + raise ValueError("Missing key: namespace") + if 'config' not in post_config: + raise ValueError("Missing key: config") + + namespace = post_config['namespace'] + config: list = post_config['config'][0]['body'] + for item in config: + key = item['path'] + value = item['value'] + typ = item['val_type'] + if typ == 'int': + if not value.isdigit(): + raise ValueError(f"Invalid type for {namespace}.{key}: expected int, got {type(value).__name__}") + value = int(value) + update_config(namespace, key, value) diff --git a/dashboard/server.py b/dashboard/server.py index 074e08776..f507d54cf 100644 --- a/dashboard/server.py +++ b/dashboard/server.py @@ -34,6 +34,7 @@ class AstrBotDashBoard(): self.dashboard_helper = DashBoardHelper(self.context) self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/") + self.dashboard_be.json.sort_keys=False # 不按照字典排序 logging.getLogger('werkzeug').setLevel(logging.ERROR) self.dashboard_be.logger.setLevel(logging.ERROR) @@ -413,66 +414,6 @@ class AstrBotDashBoard(): "body": list(json.load(f).values()) },] - def _generate_outline(self): - ''' - 生成配置大纲。目前分为 platform(消息平台配置) 和 llm(语言模型配置) 两大类。 - 插件的info函数中如果带了plugin_type字段,则会被归类到对应的大纲中。目前仅支持 platform 和 llm 两种类型。 - ''' - outline = [ - { - "type": "platform", - "name": "配置通用消息平台", - "body": [ - { - "title": "通用", - "desc": "通用平台配置", - "namespace": "internal_platform_general", - "tag": "" - }, - { - "title": "QQ(官方)", - "desc": "QQ官方API。支持频道、群、私聊(需获得群权限)", - "namespace": "internal_platform_qq_official", - "tag": "" - }, - { - "title": "QQ(nakuru)", - "desc": "适用于 go-cqhttp", - "namespace": "internal_platform_qq_gocq", - "tag": "" - }, - { - "title": "QQ(aiocqhttp)", - "desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。", - "namespace": "internal_platform_qq_aiocqhttp", - "tag": "" - } - ] - }, - { - "type": "llm", - "name": "配置 LLM", - "body": [ - { - "title": "OpenAI Official", - "desc": "也支持使用官方接口的中转服务", - "namespace": "internal_llm_openai_official", - "tag": "" - } - ] - } - ] - for plugin in self.context.cached_plugins: - for item in outline: - if item['type'] == plugin.metadata.plugin_type: - item['body'].append({ - "title": plugin.metadata.plugin_name, - "desc": plugin.metadata.desc, - "namespace": plugin.metadata.plugin_name, - "tag": plugin.metadata.plugin_name - }) - return outline - async def get_log_history(self): try: with open("logs/astrbot/astrbot.log", "r", encoding="utf-8") as f: diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 69a8cee00..8f2ffb870 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -69,7 +69,8 @@ class ProviderOpenAIOfficial(Provider): base_url=self.base_url ) super().set_curr_model(llm_config.model_config.model) - self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config) + if llm_config.image_generation_model_config: + self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config) self.session_memory: Dict[str, List] = {} # 会话记忆 self.session_memory_lock = threading.Lock() self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小 @@ -427,11 +428,10 @@ class ProviderOpenAIOfficial(Provider): ''' retry = 0 conf = self.image_generator_model_configs - super().accu_model_stat(model=conf['model']) if not conf: logger.error("OpenAI 图片生成模型配置不存在。") raise Exception("OpenAI 图片生成模型配置不存在。") - + super().accu_model_stat(model=conf['model']) while retry < 3: try: images_response = await self.client.images.generate( diff --git a/type/config.py b/type/config.py index 3a44a0b2b..17096ea5c 100644 --- a/type/config.py +++ b/type/config.py @@ -79,6 +79,7 @@ DEFAULT_CONFIG_VERSION_2 = { "config_version": 2, "platform": [ { + "id": "default", "name": "qq_official", "enable": False, "appid": "", @@ -87,6 +88,7 @@ DEFAULT_CONFIG_VERSION_2 = { "enable_guild_direct_message": True, }, { + "id": "default", "name": "nakuru", "enable": False, "host": "172.0.0.1", @@ -98,6 +100,7 @@ DEFAULT_CONFIG_VERSION_2 = { "enable_group_increase": True, }, { + "id": "default", "name": "aiocqhttp", "enable": False, "ws_reverse_host": "", @@ -116,6 +119,7 @@ DEFAULT_CONFIG_VERSION_2 = { }, "llm": [ { + "id": "default", "name": "openai", "enable": True, "key": [], diff --git a/util/cmd_config.py b/util/cmd_config.py index bb4e56a28..befc5e9b8 100644 --- a/util/cmd_config.py +++ b/util/cmd_config.py @@ -26,6 +26,7 @@ class PlatformSettings: @dataclass class PlatformConfig(): + id: str = "" name: str = "" enable: bool = False @@ -70,6 +71,7 @@ class ImageGenerationModelConfig: @dataclass class LLMConfig: + id: str = "" name: str = "openai" enable: bool = True key: List[str] = field(default_factory=list) @@ -77,12 +79,12 @@ class LLMConfig: prompt_prefix: str = "" default_personality: str = "" model_config: ModelConfig = field(default_factory=ModelConfig) - image_generation_model_config: ImageGenerationModelConfig = field(default_factory=ImageGenerationModelConfig) + image_generation_model_config: Optional[ImageGenerationModelConfig] = None def __post_init__(self): self.model_config = ModelConfig(**self.model_config) - self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) - + if self.image_generation_model_config: + self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config) @dataclass class LLMSettings: wake_prefix: str = "" @@ -245,6 +247,8 @@ class AstrBotConfig(): conf_str = f.read() if conf_str.startswith(u'/ufeff'): # remove BOM conf_str = conf_str.encode('utf8')[3:].decode('utf8') + if not conf_str: + return {} conf = json.loads(conf_str) return conf From 6dfbaf1b886e87ff3a504560ca8f8210ed82a924 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 10 Sep 2024 01:57:13 -0400 Subject: [PATCH 3/5] bugfixes --- .coverage | Bin 0 -> 53248 bytes dashboard/server.py | 4 ++-- model/command/internal_handler.py | 13 +++++++------ model/platform/qq_aiocqhttp.py | 2 +- model/platform/qq_nakuru.py | 2 +- model/platform/qq_official.py | 2 +- tests/test_message.py | 9 +++++++-- type/types.py | 2 +- util/cmd_config.py | 7 +++++-- 9 files changed, 25 insertions(+), 16 deletions(-) create mode 100644 .coverage diff --git a/.coverage b/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..439ed81bb3736293f18c8ea1646ebe595a8e25d8 GIT binary patch literal 53248 zcmeHQdvqMtd7s&t-Pza7T|L)&StI!&$$G40OY#fw+Qfi?1`{~YhQfNa8c7SgJJQZb zwgBs4oIoiqNn1i`32n~F_dhGI_x;eH^gQl+PKtVB@PJ(m z0fqoWfFbbz8-eND{Zg>1ia&JFjP@loGaZd-=`5A6|IF^4ckWj2+_~%4-D);ZZT6~g zwYRG~)wI4}9oEw7Ks=$T@l=027B%CkLDd}6tkR=~){hGvI2$%8!eF2W@-n_Yz9=K#RpSZLT%28x2S1tKuc?>m}X?HXpZ-{cv~Ks z^h&|9GX7DGwkev9!H;iVYak=vlsyAbQA|(wqxF61Xeu_O8EtBGu7xoj>^xv*TbtIC zAyZcq@f2-l%7~kBJ*8>~wAiSr^)K4MHU!PO0KCH1KwbN=LHn|jIb8+lD-sJB9%htTh?!&klF&l!6-f4{29amtBoM6*V<8p43zvFa#I^3;~7! zLx3T`5MT%}1Q-HW9|8jJ;3bUz9pp4e&H`Z2aB9kKqzqm%GLO)&RI0Gm5{JG#8kDhGL&Bkz)relr`$tY!!>1Q-Gg0fqoW zfFZyTUcpd&>7)KGFNH-Y2|wdCNRM_B`wPtVebK%>BIE zbZ?d}NiRx|Nw-Pmt~Xs@cMZGRoadb1a*jB6INjn);^)M$@S!jR%GkvaUh2LZefg53W;_vslhKBrio(&; zage&XTjZvzmKRIv{aOOfdeZY+A$?d&MdJ~DU?3igM-!1DIK-9E(oom+o1iYAow}e2 zWAT1y%Yw?FlJG{T#A&ONMAU@isL9aCNF*B9V*1N7MbGAuR!(i&8BskTSG1EN&px+bnR#^5*7sMI!N3+>As(ZfldkjV~`p zo4H@>Gc-6kITRYy^+7mjI6S_2y}*sGNYu~=b=0|Tr8>0~U`%zcU8&AKc((xRT(eT0 zbKSVTQQ!{RXeT&s$_xQekt|ZQU$!*Z^gzpthofoeRIs>LE!v~z_^=jAYK8%-Bib0K z2e~cvMK#hM&6p%}$LC!ES1jB7?HrJ4Kn^aGJx7A|qEL4mv1Mon=XdpA~s z-ZC4d7wAps;3S1wZ>WHZN>*M`U)0bF>)Bauvj-L!R4|A2l-XO)um*T(3HS~5oSstq zVg=elB?6}{-)=fC=jM-10f9>{FF+k6mh6wr35AF~E!lwt9owlCU7d4g(QOV$Y>|r& zcPe4Id~Tuv$?g84Gom$7F7%X9UbApM}5AH>Kh2o2pW|&4OsmC5LLhu3=m^{%|M^RhN z0y)?L7wlhqd7DbHib&pV@%#T?;VF*PDW{a@ls?y6t`XNagr}U(%Rd)Ba*xV4JJaIl z#g{z|&R>Z;m38D!@&tKB4){;_Z*+g#|1J4Z`33hG_ZI&Fr^oYU@9WO6aLm>24|wnP z{hRMgzEAs{-Xq>}&pn=>NZ*$(OTQ;oNPCwIeAokq07HP|T7 zWpsZ0-*byNT~$PE{#vjgWG{#&yFVpP`)t)^tqF_7|C@G-(@uMpSV8fE__Sh`D)P1-r54ZNhjFIoixc;{T>C z!syCGbMgQB73-Xf|JSWpXDj^YJj?;`R4sy5+XJ2lbxKX{WkU1q2B z0=-wn|CKk2(yh3~MNe|I1b&R-mo4Q_zaEn~ux5`J<#; zNEQ*G4l*zP57^W)7ypyut8>mQEB;rC4|ghIxqLp5l8euX)5MT%}1Q-Gg z0fqoWfFZyTULD~^0X-XU+2e<072Bcz|)MZ%g}tjZ;_!~cf=JO1bV zNBoEVxBA!k{@r&JRI!U8zz|>vFa#I^3;~7!Lx3T`5U`7Yw@2hXl|7d-nGC7zhV-`X z*{9z-<`lUpPUS@U|Bj11)TZ9B$;o@B{hMZwMlNS2-~aQmV}D<_SA;@bUZM(cj2s1^UohlJ$Y#Vx&Qdl6CYlaSy_~>t z!#wBz93SSmo1{*V3v}Kkk`N@+(22`V(9Utk9f1zWXzcg|&SO!F0^#O%$oWkBC!w=w zvYVrgz*_JEd$Wz2g7o?}DEaBO*~{W7k2}sy@sL=>aW4Eh&T*aa!E`6U6IQ1< zx0_QBOU)o3XnsR-H$kGbY3yC6Z#|?|tsi^0EG)X^b#PO=?hnrV;J_X(;P$SC+sd`u z9_B?}S_A38n(Kr>BP1Fd-_JZ9gcgETVRr-Mlr~hKJ|wGBSI=BJ?{L?lRNW^(XqRds6{sEKt7;(8U2|WkUqtBE@|;6OjKD=m zH9S(QCys(6gdN~dQWXkSy&y`fFtKVZQXYiFhTv=mA5b|cg5dLEcO~SOR*pFXt_nyg z7290p$jcR}40xbyX2w;D8KsVmCBO$tKKc0a#Rxd`u+vo8?PPuQ!7*T>=NOpQ>Gfzw&KX$qvni}5A9r_}tmV4mA2G8u|<%_e^ z=R!{KqK94A_}!3O?SAlg0<(NL+$6!>UTNl7$PI3Bofim)DL~A@IRash3&TGL_?`<2 z*18ZzxWpzW$~k8;k=eJu@RszL&nceD+;_-%_!X{k*Q>P5R0GGI;wwZb*e1?qLTB#% z+Qfq=p%+d~p14I4AUhzu(d%_UqSEokLucP>m3h(QS)H0Z`u@d>{G|!}{=Y|gk|P(% zhvb*!XXGS#o&1peBYBa0kNgdJo_vG+1$l=2A-o;%FgZxZh)#w`lzbN66?iMzO|B!G zNhfI`jiiQ@6PdV_jB-KwK>4}yQ{_$N$I7e9OUiNOyYTyiZz^AfHw%6VRI!U8zz|>v zFa#I^3;~7!Lx3T`5MT%}1Q-Gs0P{UPly!GgwrLY(8#hwc)kRrnCuN}!WgQ)qwYO8& z)<#)tD`gutP}b5yS#vXGO-+=oUr*V(b(F1LOWB$=lr=U|*3dwis!~>8Puc3#l-1Qy zR$EJ1O$}w$)s$6LQMPIoWx*h2m6eoLR8Ur4PFYzQWu>K*m6T8x2v9}{Wr{+XOn*PX z`~CFJ=cCN)rOe}@%E@BeeZG-KY!5MT%}1Q-Gg z0fqoWfFZyTUvFa#I^3;~7! zLx3TWLjceJ_bTV`@BcrBcL4l?oFQ+Mx8MxGzmQkSE9Cp+7O$u zG@J!ELLMWJkO?wQOgIm4KhelOau2*S;C46@a3k48t|eQ@Cei`t0@jjxQcWsIfcS`u z@X96SJSbxqLx3T`5MT%}1Q-Gg0fqoWfFZyTU bool: return os.path.exists(ASTRBOT_CONFIG_PATH) \ No newline at end of file From 5fc4693b9ca6a66dc3721543b59037232f942fe6 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 10 Sep 2024 01:57:51 -0400 Subject: [PATCH 4/5] remove: .coverage --- .coverage | Bin 53248 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .coverage diff --git a/.coverage b/.coverage deleted file mode 100644 index 439ed81bb3736293f18c8ea1646ebe595a8e25d8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 53248 zcmeHQdvqMtd7s&t-Pza7T|L)&StI!&$$G40OY#fw+Qfi?1`{~YhQfNa8c7SgJJQZb zwgBs4oIoiqNn1i`32n~F_dhGI_x;eH^gQl+PKtVB@PJ(m z0fqoWfFbbz8-eND{Zg>1ia&JFjP@loGaZd-=`5A6|IF^4ckWj2+_~%4-D);ZZT6~g zwYRG~)wI4}9oEw7Ks=$T@l=027B%CkLDd}6tkR=~){hGvI2$%8!eF2W@-n_Yz9=K#RpSZLT%28x2S1tKuc?>m}X?HXpZ-{cv~Ks z^h&|9GX7DGwkev9!H;iVYak=vlsyAbQA|(wqxF61Xeu_O8EtBGu7xoj>^xv*TbtIC zAyZcq@f2-l%7~kBJ*8>~wAiSr^)K4MHU!PO0KCH1KwbN=LHn|jIb8+lD-sJB9%htTh?!&klF&l!6-f4{29amtBoM6*V<8p43zvFa#I^3;~7! zLx3T`5MT%}1Q-HW9|8jJ;3bUz9pp4e&H`Z2aB9kKqzqm%GLO)&RI0Gm5{JG#8kDhGL&Bkz)relr`$tY!!>1Q-Gg0fqoW zfFZyTUcpd&>7)KGFNH-Y2|wdCNRM_B`wPtVebK%>BIE zbZ?d}NiRx|Nw-Pmt~Xs@cMZGRoadb1a*jB6INjn);^)M$@S!jR%GkvaUh2LZefg53W;_vslhKBrio(&; zage&XTjZvzmKRIv{aOOfdeZY+A$?d&MdJ~DU?3igM-!1DIK-9E(oom+o1iYAow}e2 zWAT1y%Yw?FlJG{T#A&ONMAU@isL9aCNF*B9V*1N7MbGAuR!(i&8BskTSG1EN&px+bnR#^5*7sMI!N3+>As(ZfldkjV~`p zo4H@>Gc-6kITRYy^+7mjI6S_2y}*sGNYu~=b=0|Tr8>0~U`%zcU8&AKc((xRT(eT0 zbKSVTQQ!{RXeT&s$_xQekt|ZQU$!*Z^gzpthofoeRIs>LE!v~z_^=jAYK8%-Bib0K z2e~cvMK#hM&6p%}$LC!ES1jB7?HrJ4Kn^aGJx7A|qEL4mv1Mon=XdpA~s z-ZC4d7wAps;3S1wZ>WHZN>*M`U)0bF>)Bauvj-L!R4|A2l-XO)um*T(3HS~5oSstq zVg=elB?6}{-)=fC=jM-10f9>{FF+k6mh6wr35AF~E!lwt9owlCU7d4g(QOV$Y>|r& zcPe4Id~Tuv$?g84Gom$7F7%X9UbApM}5AH>Kh2o2pW|&4OsmC5LLhu3=m^{%|M^RhN z0y)?L7wlhqd7DbHib&pV@%#T?;VF*PDW{a@ls?y6t`XNagr}U(%Rd)Ba*xV4JJaIl z#g{z|&R>Z;m38D!@&tKB4){;_Z*+g#|1J4Z`33hG_ZI&Fr^oYU@9WO6aLm>24|wnP z{hRMgzEAs{-Xq>}&pn=>NZ*$(OTQ;oNPCwIeAokq07HP|T7 zWpsZ0-*byNT~$PE{#vjgWG{#&yFVpP`)t)^tqF_7|C@G-(@uMpSV8fE__Sh`D)P1-r54ZNhjFIoixc;{T>C z!syCGbMgQB73-Xf|JSWpXDj^YJj?;`R4sy5+XJ2lbxKX{WkU1q2B z0=-wn|CKk2(yh3~MNe|I1b&R-mo4Q_zaEn~ux5`J<#; zNEQ*G4l*zP57^W)7ypyut8>mQEB;rC4|ghIxqLp5l8euX)5MT%}1Q-Gg z0fqoWfFZyTULD~^0X-XU+2e<072Bcz|)MZ%g}tjZ;_!~cf=JO1bV zNBoEVxBA!k{@r&JRI!U8zz|>vFa#I^3;~7!Lx3T`5U`7Yw@2hXl|7d-nGC7zhV-`X z*{9z-<`lUpPUS@U|Bj11)TZ9B$;o@B{hMZwMlNS2-~aQmV}D<_SA;@bUZM(cj2s1^UohlJ$Y#Vx&Qdl6CYlaSy_~>t z!#wBz93SSmo1{*V3v}Kkk`N@+(22`V(9Utk9f1zWXzcg|&SO!F0^#O%$oWkBC!w=w zvYVrgz*_JEd$Wz2g7o?}DEaBO*~{W7k2}sy@sL=>aW4Eh&T*aa!E`6U6IQ1< zx0_QBOU)o3XnsR-H$kGbY3yC6Z#|?|tsi^0EG)X^b#PO=?hnrV;J_X(;P$SC+sd`u z9_B?}S_A38n(Kr>BP1Fd-_JZ9gcgETVRr-Mlr~hKJ|wGBSI=BJ?{L?lRNW^(XqRds6{sEKt7;(8U2|WkUqtBE@|;6OjKD=m zH9S(QCys(6gdN~dQWXkSy&y`fFtKVZQXYiFhTv=mA5b|cg5dLEcO~SOR*pFXt_nyg z7290p$jcR}40xbyX2w;D8KsVmCBO$tKKc0a#Rxd`u+vo8?PPuQ!7*T>=NOpQ>Gfzw&KX$qvni}5A9r_}tmV4mA2G8u|<%_e^ z=R!{KqK94A_}!3O?SAlg0<(NL+$6!>UTNl7$PI3Bofim)DL~A@IRash3&TGL_?`<2 z*18ZzxWpzW$~k8;k=eJu@RszL&nceD+;_-%_!X{k*Q>P5R0GGI;wwZb*e1?qLTB#% z+Qfq=p%+d~p14I4AUhzu(d%_UqSEokLucP>m3h(QS)H0Z`u@d>{G|!}{=Y|gk|P(% zhvb*!XXGS#o&1peBYBa0kNgdJo_vG+1$l=2A-o;%FgZxZh)#w`lzbN66?iMzO|B!G zNhfI`jiiQ@6PdV_jB-KwK>4}yQ{_$N$I7e9OUiNOyYTyiZz^AfHw%6VRI!U8zz|>v zFa#I^3;~7!Lx3T`5MT%}1Q-Gs0P{UPly!GgwrLY(8#hwc)kRrnCuN}!WgQ)qwYO8& z)<#)tD`gutP}b5yS#vXGO-+=oUr*V(b(F1LOWB$=lr=U|*3dwis!~>8Puc3#l-1Qy zR$EJ1O$}w$)s$6LQMPIoWx*h2m6eoLR8Ur4PFYzQWu>K*m6T8x2v9}{Wr{+XOn*PX z`~CFJ=cCN)rOe}@%E@BeeZG-KY!5MT%}1Q-Gg z0fqoWfFZyTUvFa#I^3;~7! zLx3TWLjceJ_bTV`@BcrBcL4l?oFQ+Mx8MxGzmQkSE9Cp+7O$u zG@J!ELLMWJkO?wQOgIm4KhelOau2*S;C46@a3k48t|eQ@Cei`t0@jjxQcWsIfcS`u z@X96SJSbxqLx3T`5MT%}1Q-Gg0fqoWfFZyTU Date: Tue, 10 Sep 2024 03:31:17 -0400 Subject: [PATCH 5/5] =?UTF-8?q?perf:=20=E5=AE=8C=E5=96=84=E8=A6=86?= =?UTF-8?q?=E7=9B=96=E7=8E=87=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/coverage_test.yml | 6 ++- .gitignore | 3 +- astrbot/message/handler.py | 5 +- model/command/internal_handler.py | 11 ++-- model/platform/qq_aiocqhttp.py | 13 +++-- model/provider/openai_official.py | 2 +- tests/mocks/onebot.py | 7 ++- tests/mocks/qq_official.py | 23 +++++--- tests/test_message.py | 82 ++++++++++++++++++++++++++++- 9 files changed, 125 insertions(+), 27 deletions(-) diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml index a021daa7c..941084087 100644 --- a/.github/workflows/coverage_test.yml +++ b/.github/workflows/coverage_test.yml @@ -26,7 +26,11 @@ jobs: mkdir temp - name: Run tests - run: PYTHONPATH=./ pytest --cov=. tests/ -v + run: | + export LLM_MODEL=${{ secrets.LLM_MODEL }} + export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }} + export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} + PYTHONPATH=./ pytest --cov=. tests/ -v - name: Upload results to Codecov uses: codecov/codecov-action@v4 diff --git a/.gitignore b/.gitignore index 7815d9b88..91514b8d7 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ cmd_config.json data/* cookies.json logs/ -addons/plugins \ No newline at end of file +addons/plugins +.coverage \ No newline at end of file diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index 501960825..b8bce1f99 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -110,8 +110,7 @@ class MessageHandler(): self.llm_wake_prefix = self.context.config_helper.llm_settings.wake_prefix if self.llm_wake_prefix: self.llm_wake_prefix = self.llm_wake_prefix.strip() - self.nicks = self.context.config_helper.wake_prefix - self.provider = self.context.llms[0] if len(self.context.llms) > 0 else None + self.provider = self.context.llms[0].llm_instance if len(self.context.llms) > 0 else None self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix) self.llm_tools = FuncCall(self.provider) @@ -140,7 +139,7 @@ class MessageHandler(): return # remove the nick prefix - for nick in self.nicks: + for nick in self.context.config_helper.wake_prefix: if msg_plain.startswith(nick): msg_plain = msg_plain.removeprefix(nick) break diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index e4e120f79..058987e4b 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -62,11 +62,12 @@ class InternalCommandHandler: return CommandResult().message("你没有权限使用该指令。") l = message_str.split(" ") if len(l) == 1: - return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.config_helper.wake_prefix}") + return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}") nick = l[1].strip() if not nick: return CommandResult().message("wake: 请指定唤醒词。") context.config_helper.wake_prefix = [nick] + context.config_helper.save_config() return CommandResult( hit=True, success=True, @@ -88,11 +89,7 @@ class InternalCommandHandler: ret = f"当前已经是最新版本 v{VERSION}。" else: ret = f"发现新版本 {update_info.version},更新内容如下:\n---\n{update_info.body}\n---\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。" - return CommandResult( - hit=True, - success=False, - message_chain=ret, - ) + return CommandResult().message(ret) else: if tokens.get(1) == "latest": try: @@ -182,7 +179,7 @@ class InternalCommandHandler: async with session.get("https://soulter.top/channelbot/notice.json") as resp: notice = (await resp.json())["notice"] except BaseException as e: - logger.warn("An error occurred while fetching astrbot notice. Never mind, it's not important.") + logger.warning("An error occurred while fetching astrbot notice. Never mind, it's not important.") msg = "# Help Center\n## 指令列表\n" for key, value in self.manager.commands_handler.items(): diff --git a/model/platform/qq_aiocqhttp.py b/model/platform/qq_aiocqhttp.py index c6f4164f8..e97e5baa6 100644 --- a/model/platform/qq_aiocqhttp.py +++ b/model/platform/qq_aiocqhttp.py @@ -112,7 +112,7 @@ class AIOCQHTTP(Platform): while self.context.running: await asyncio.sleep(1) - def pre_check(self, message: AstrBotMessage) -> bool: + async def pre_check(self, message: AstrBotMessage) -> bool: # if message chain contains Plain components or # At components which points to self_id, return True if message.type == MessageType.FRIEND_MESSAGE: @@ -121,7 +121,7 @@ class AIOCQHTTP(Platform): if isinstance(comp, At) and str(comp.qq) == message.self_id: return True, "at" # check commands which ignore prefix - if self.context.command_manager.check_command_ignore_prefix(message.message_str): + if await self.context.command_manager.check_command_ignore_prefix(message.message_str): return True, "command" # check nicks if self.check_nick(message.message_str): @@ -132,7 +132,7 @@ class AIOCQHTTP(Platform): logger.info( f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") - ok, reason = self.pre_check(message) + ok, reason = await self.pre_check(message) if not ok: return @@ -173,6 +173,8 @@ class AIOCQHTTP(Platform): # 如果是等待回复的消息 if message.session_id in self.waiting and self.waiting[message.session_id] == '': self.waiting[message.session_id] = message + + return message_result async def reply_msg(self, @@ -188,17 +190,18 @@ class AIOCQHTTP(Platform): res = [Plain(text=res), ] # if image mode, put all Plain texts into a new picture. - if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list): + if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(result_message, list): rendered_images = await self.convert_to_t2i_chain(res) if rendered_images: try: await self._reply(message, rendered_images) - return + return rendered_images except BaseException as e: logger.warn(traceback.format_exc()) logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。") await self._reply(message, res) + return res async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]): await self.record_metrics() diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index e1bf13ab0..2b6e08544 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -355,10 +355,10 @@ class ProviderOpenAIOfficial(Provider): if ok: continue else: raise Exception("所有 OpenAI API Key 目前都不可用。") except BadRequestError as e: + retry += 1 logger.warn(f"OpenAI 请求异常:{e}。") if "image_url is only supported by certain models." in str(e): raise Exception(f"当前模型 { self.get_curr_model() } 不支持图片输入,请更换模型。") - retry += 1 except RateLimitError as e: if "You exceeded your current quota" in str(e): self.keys_data[self.chosen_api_key] = False diff --git a/tests/mocks/onebot.py b/tests/mocks/onebot.py index 66df3d1ee..1b204c507 100644 --- a/tests/mocks/onebot.py +++ b/tests/mocks/onebot.py @@ -1,3 +1,4 @@ +import copy from aiocqhttp import Event class MockOneBotMessage(): @@ -10,4 +11,8 @@ class MockOneBotMessage(): return self.group_event_sample def create_random_direct_message(self): - return self.friend_event_sample \ No newline at end of file + return self.friend_event_sample + + def create_msg(self, text: str): + self.group_event_sample.message = [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': text}, 'type': 'text'}] + return self.group_event_sample \ No newline at end of file diff --git a/tests/mocks/qq_official.py b/tests/mocks/qq_official.py index 0d665d289..0978502aa 100644 --- a/tests/mocks/qq_official.py +++ b/tests/mocks/qq_official.py @@ -3,19 +3,19 @@ import botpy.message class MockQQOfficialMessage(): def __init__(self): # 这些数据已经经过去敏处理 - self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'} - self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'} - self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'} + self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T19:58:52+08:00'} + self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:06:32+08:00'} + self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:15:24+08:00'} self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849" self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'} - self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'} - self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'} + self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'} + self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'} self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'} - self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'} - self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'} + self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'} + self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'} self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64" def create_random_group_message(self): @@ -42,4 +42,13 @@ class MockQQOfficialMessage(): ) return mocked + def create_msg(self, text: str): + sample = self.group_plain_text_sample.copy() + sample['content'] = text + mocked = botpy.message.Message( + api=None, + event_id=self.group_event_id_sample, + data=sample + ) + return mocked diff --git a/tests/test_message.py b/tests/test_message.py index 809b87164..7015785e1 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -8,6 +8,7 @@ from tests.mocks.onebot import MockOneBotMessage from astrbot.bootstrap import AstrBotBootstrap from model.platform.qq_official import QQOfficial from model.platform.qq_aiocqhttp import AIOCQHTTP +from model.provider.openai_official import ProviderOpenAIOfficial from type.astrbot_message import * from type.message_event import * from SparkleLogging.utils.core import LogManager @@ -24,7 +25,17 @@ pytest_plugins = ('pytest_asyncio',) os.environ['TEST_MODE'] = 'on' bootstrap = AstrBotBootstrap() + +llm_config = bootstrap.context.config_helper.llm[0] +llm_config.api_base = os.environ['OPENAI_API_BASE'] +llm_config.key = [os.environ['OPENAI_API_KEY']] +llm_config.model_config.model = os.environ['LLM_MODEL'] +llm_config.model_config.max_tokens = 1000 +llm_provider = ProviderOpenAIOfficial(llm_config) asyncio.run(bootstrap.run()) +bootstrap.message_handler.provider = llm_provider +bootstrap.config_helper.wake_prefix = ["/"] +bootstrap.config_helper.admins_id = ["905617992"] for p_config in bootstrap.context.config_helper.platform: if isinstance(p_config, QQOfficialPlatformConfig): @@ -67,4 +78,73 @@ class TestBasicMessageHandle(): event = MockOneBotMessage().create_random_direct_message() abm = aiocqhttp.convert_message(event) ret = await aiocqhttp.handle_msg(abm) - print(ret) \ No newline at end of file + print(ret) + +class TestInteralCommandHsandle(): + def create(self, text: str): + event = MockOneBotMessage().create_msg(text) + abm = aiocqhttp.convert_message(event) + return abm + + async def fast_test(self, text: str): + abm = self.create(text) + ret = await aiocqhttp.handle_msg(abm) + print(f"Command: {text}, Result: {ret.result_message}") + return ret + + @pytest.mark.asyncio + async def test_config_save(self): + abm = self.create("/websearch on") + ret = await aiocqhttp.handle_msg(abm) + assert bootstrap.context.config_helper.llm_settings.web_search \ + == bootstrap.config_helper.get("llm_settings")['web_search'] + + @pytest.mark.asyncio + async def test_websearch(self): + await self.fast_test("/websearch") + await self.fast_test("/websearch on") + await self.fast_test("/websearch off") + + @pytest.mark.asyncio + async def test_help(self): + await self.fast_test("/help") + + @pytest.mark.asyncio + async def test_myid(self): + await self.fast_test("/myid") + + @pytest.mark.asyncio + async def test_wake(self): + await self.fast_test("/wake") + await self.fast_test("/wake #") + assert "#" in bootstrap.context.config_helper.wake_prefix + assert "#" in bootstrap.context.config_helper.get("wake_prefix") + await self.fast_test("#wake /") + + @pytest.mark.asyncio + async def test_sleep(self): + await self.fast_test("/provider") + + @pytest.mark.asyncio + async def test_update(self): + await self.fast_test("/update") + + @pytest.mark.asyncio + async def test_t2i(self): + if not bootstrap.context.config_helper.t2i: + abm = self.create("/t2i") + await aiocqhttp.handle_msg(abm) + await self.fast_test("/help") + +class TestLLMChat(): + @pytest.mark.asyncio + async def test_llm_chat(self): + os.environ["TEST_LLM"] = "on" + ret = await llm_provider.text_chat("Just reply `ok`", "test") + print(ret) + event = MockOneBotMessage().create_msg("Just reply `ok`") + abm = aiocqhttp.convert_message(event) + ret = await aiocqhttp.handle_msg(abm) + print(ret) + os.environ["TEST_LLM"] = "off" + \ No newline at end of file