feat: 重构配置格式

perf: 优化配置处理过程和呈现方式
This commit is contained in:
Soulter
2024-08-06 04:58:29 -04:00
parent 14dbdb2d83
commit f8aef78d25
19 changed files with 829 additions and 833 deletions
+24 -29
View File
@@ -13,7 +13,7 @@ from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from util.cmd_config import CmdConfig
from util.cmd_config import AstrBotConfig
from util.metrics import MetricUploader
from util.config_utils import *
from util.updator.astrbot_updator import AstrBotUpdator
@@ -24,29 +24,15 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotBootstrap():
def __init__(self) -> None:
self.context = Context()
self.config_helper = CmdConfig()
# load configs and ensure the backward compatibility
try_migrate_config()
self.config_helper = AstrBotConfig()
self.context.config_helper = self.config_helper
self.context.base_config = self.config_helper.cached_config
self.context.default_personality = {
"name": "default",
"prompt": self.context.base_config.get("default_personality_str", ""),
}
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
self.context.nick = nick_qq
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
self.context.version = VERSION
logger.info("AstrBot v" + self.context.version)
logger.info("AstrBot v" + VERSION)
# apply proxy settings
http_proxy = self.context.base_config.get("http_proxy")
https_proxy = self.context.base_config.get("https_proxy")
http_proxy = self.context.config_helper.http_proxy
https_proxy = self.context.config_helper.https_proxy
if http_proxy:
os.environ['HTTP_PROXY'] = http_proxy
if https_proxy:
@@ -66,10 +52,9 @@ class AstrBotBootstrap():
self.db_conn_helper = dbConn()
# load llm provider
self.llm_instance: Provider = None
self.load_llm()
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper, self.llm_instance)
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper)
self.platfrom_manager = PlatformManager(self.context, self.message_handler)
self.dashboard = AstrBotDashBoard(self.context, plugin_manager=self.plugin_manager, astrbot_updator=self.updator)
self.metrics_uploader = MetricUploader(self.context)
@@ -105,16 +90,26 @@ class AstrBotBootstrap():
await asyncio.sleep(5)
def load_llm(self):
if 'openai' in self.config_helper.cached_config and \
len(self.config_helper.cached_config['openai']['key']) and \
self.config_helper.cached_config['openai']['key'][0] is not None:
from model.provider.openai_official import ProviderOpenAIOfficial
f = False
llms = self.context.config_helper.llm
logger.info(f"加载 {len(llms)} 个 LLM Provider...")
for llm in llms:
if llm.enable:
if llm.name == "openai" and llm.key and llm.enable:
self.load_openai(llm)
f = True
logger.info(f"已启用 OpenAI API 支持。")
else:
logger.warn(f"未知的 LLM Provider: {llm.name}")
if f:
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
self.llm_instance = ProviderOpenAIOfficial(self.context)
self.openai_command_handler.set_provider(self.llm_instance)
self.context.register_provider("internal_openai", self.llm_instance)
logger.info("已启用 OpenAI API 支持。")
self.openai_command_handler.set_provider(self.context.llms[0].llm_instance)
def load_openai(self, llm_config):
from model.provider.openai_official import ProviderOpenAIOfficial
inst = ProviderOpenAIOfficial(llm_config)
self.context.register_provider("internal_openai", inst)
def load_plugins(self):
self.plugin_manager.plugin_reload()
+8 -9
View File
@@ -1,16 +1,15 @@
from aip import AipContentCensor
from util.cmd_config import BaiduAIPConfig
class BaiduJudge:
def __init__(self, baidu_configs) -> None:
if 'app_id' in baidu_configs and 'api_key' in baidu_configs and 'secret_key' in baidu_configs:
self.app_id = str(baidu_configs['app_id'])
self.api_key = baidu_configs['api_key']
self.secret_key = baidu_configs['secret_key']
self.client = AipContentCensor(
self.app_id, self.api_key, self.secret_key)
else:
raise ValueError("Baidu configs error! 请填写百度内容审核服务相关配置!")
def __init__(self, baidu_configs: BaiduAIPConfig) -> None:
self.app_id = baidu_configs.app_id
self.api_key = baidu_configs.api_key
self.secret_key = baidu_configs.secret_key
self.client = AipContentCensor(self.app_id,
self.api_key,
self.secret_key)
def judge(self, text):
res = self.client.textCensorUserDefined(text)
+13 -20
View File
@@ -22,16 +22,11 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class RateLimitHelper():
def __init__(self, context: Context) -> None:
self.user_rate_limit: Dict[int, int] = {}
self.rate_limit_time: int = 60
self.rate_limit_count: int = 10
rl = context.config_helper.platform_settings.rate_limit
self.rate_limit_time: int = rl.time
self.rate_limit_count: int = rl.count
self.user_frequency = {}
if 'limit' in context.base_config:
if 'count' in context.base_config['limit']:
self.rate_limit_count = context.base_config['limit']['count']
if 'time' in context.base_config['limit']:
self.rate_limit_time = context.base_config['limit']['time']
def check_frequency(self, session_id: str) -> bool:
'''
检查发言频率
@@ -56,12 +51,11 @@ class RateLimitHelper():
class ContentSafetyHelper():
def __init__(self, context: Context) -> None:
self.baidu_judge = None
if 'baidu_api' in context.base_config and \
'enable' in context.base_config['baidu_aip'] and \
context.base_config['baidu_aip']['enable']:
aip = context.config_helper.content_safety.baidu_aip
if aip.enable:
try:
from astrbot.message.baidu_aip_judge import BaiduJudge
self.baidu_judge = BaiduJudge(context.base_config['baidu_aip'])
self.baidu_judge = BaiduJudge(aip)
logger.info("已启用百度 AI 内容审核。")
except BaseException as e:
logger.error("百度 AI 内容审核初始化失败。")
@@ -104,19 +98,18 @@ class ContentSafetyHelper():
class MessageHandler():
def __init__(self, context: Context,
command_manager: CommandManager,
persist_manager: dbConn,
provider: Provider) -> None:
persist_manager: dbConn) -> None:
self.context = context
self.command_manager = command_manager
self.persist_manager = persist_manager
self.rate_limit_helper = RateLimitHelper(context)
self.content_safety_helper = ContentSafetyHelper(context)
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
self.llm_wake_prefix = self.context.config_helper.llm_settings.wake_prefix
if self.llm_wake_prefix:
self.llm_wake_prefix = self.llm_wake_prefix.strip()
self.nicks = self.context.nick
self.provider = provider
self.reply_prefix = str(self.context.reply_prefix)
self.nicks = self.context.config_helper.wake_prefix
self.provider = self.context.llms[0] if len(self.context.llms) > 0 else None
self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix)
def set_provider(self, provider: Provider):
self.provider = provider
@@ -176,7 +169,7 @@ class MessageHandler():
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
break
web_search = self.context.web_search
web_search = self.context.config_helper.llm_settings.web_search
if not web_search and msg_plain.startswith("ws"):
# leverage web search feature
web_search = True
-1
View File
@@ -2,7 +2,6 @@ from dataclasses import dataclass
class DashBoardData():
stats: dict = {}
configs: dict = {}
@dataclass
class Response():
+45 -523
View File
@@ -1,537 +1,59 @@
import threading
import asyncio
from . import DashBoardData
from typing import Union, Optional
from util.cmd_config import CmdConfig
from dataclasses import dataclass
from util.cmd_config import AstrBotConfig
from dataclasses import dataclass, asdict
from util.plugin_dev.api.v1.config import update_config
from SparkleLogging.utils.core import LogManager
from logging import Logger
from type.types import Context
from type.config import CONFIG_METADATA_2
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@dataclass
class DashBoardConfig():
config_type: str
name: Optional[str] = None
description: Optional[str] = None
path: Optional[str] = None # 仅 item 才需要
body: Optional[list['DashBoardConfig']] = None # 仅 group 才需要
value: Optional[Union[list, dict, str, int, bool]] = None # 仅 item 才需要
val_type: Optional[str] = None # 仅 item 才需要
class DashBoardHelper():
def __init__(self, context: Context, dashboard_data: DashBoardData):
dashboard_data.configs = {
"data": []
}
def __init__(self, context: Context):
self.context = context
self.parse_default_config(dashboard_data, context.base_config)
self.config_key_dont_show = ['dashboard', 'config_version']
# 将 config.yaml、 中的配置解析到 dashboard_data.configs 中
def parse_default_config(self, dashboard_data: DashBoardData, config: dict):
try:
qq_official_platform_group = DashBoardConfig(
config_type="group",
name="QQ(官方)",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用 QQ_OFFICIAL 平台",
description="官方的接口,仅支持 QQ 频道。详见 q.qq.com",
value=config['qqbot']['enable'],
path="qqbot.enable",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="QQ机器人APPID",
description="详见 q.qq.com",
value=config['qqbot']['appid'],
path="qqbot.appid",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="QQ机器人令牌",
description="详见 q.qq.com",
value=config['qqbot']['token'],
path="qqbot.token",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="QQ机器人 Secret",
description="详见 q.qq.com",
value=config['qqbot_secret'],
path="qqbot_secret",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否允许 QQ 频道私聊",
description="如果启用,机器人会响应私聊消息。",
value=config['direct_message_mode'],
path="direct_message_mode",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否接收QQ群消息",
description="需要机器人有相应的群消息接收权限。在 q.qq.com 上查看。",
value=config['qqofficial_enable_group_message'],
path="qqofficial_enable_group_message",
),
]
)
qq_gocq_platform_group = DashBoardConfig(
config_type="group",
name="QQ(nakuru)",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用",
description="",
value=config['gocqbot']['enable'],
path="gocqbot.enable",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="HTTP 服务器地址",
description="",
value=config['gocq_host'],
path="gocq_host",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="HTTP 服务器端口",
description="",
value=config['gocq_http_port'],
path="gocq_http_port",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="WebSocket 服务器端口",
description="目前仅支持正向 WebSocket",
value=config['gocq_websocket_port'],
path="gocq_websocket_port",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应群消息",
description="",
value=config['gocq_react_group'],
path="gocq_react_group",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应私聊消息",
description="",
value=config['gocq_react_friend'],
path="gocq_react_friend",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应群成员增加消息",
description="",
value=config['gocq_react_group_increase'],
path="gocq_react_group_increase",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应频道消息",
description="",
value=config['gocq_react_guild'],
path="gocq_react_guild",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="转发阈值(字符数)",
description="机器人回复的消息长度超出这个值后,会被折叠成转发卡片发出以减少刷屏。",
value=config['qq_forward_threshold'],
path="qq_forward_threshold",
),
]
)
qq_aiocqhttp_platform_group = DashBoardConfig(
config_type="group",
name="QQ(aiocqhttp)",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用",
description="",
value=config['aiocqhttp']['enable'],
path="aiocqhttp.enable",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="WebSocket 反向连接 host",
description="",
value=config['aiocqhttp']['ws_reverse_host'],
path="aiocqhttp.ws_reverse_host",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="WebSocket 反向连接 port",
description="",
value=config['aiocqhttp']['ws_reverse_port'],
path="aiocqhttp.ws_reverse_port",
),
]
)
general_platform_detail_group = DashBoardConfig(
config_type="group",
name="通用平台配置",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启动消息文字转图片",
description="启动后,机器人会将消息转换为图片发送,以降低风控风险。",
value=config['qq_pic_mode'],
path="qq_pic_mode",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="消息限制时间",
description="在此时间内,机器人不会回复同一个用户的消息。单位:秒",
value=config['limit']['time'],
path="limit.time",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="消息限制次数",
description="在上面的时间内,如果用户发送消息超过此次数,则机器人不会回复。单位:次",
value=config['limit']['count'],
path="limit.count",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="回复前缀",
description="[xxxx] 你好! 其中xxxx是你可以填写的前缀。如果为空则不显示。",
value=config['reply_prefix'],
path="reply_prefix",
),
DashBoardConfig(
config_type="item",
val_type="list",
name="通用管理员用户 ID(支持多个管理员)。通过 !myid 指令获取。",
description="",
value=config['other_admins'],
path="other_admins",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="独立会话",
description="是否启用独立会话模式,即 1 个用户自然账号 1 个会话。",
value=config['uniqueSessionMode'],
path="uniqueSessionMode",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="LLM 唤醒词",
description="如果不为空, 那么只有当消息以此词开头时,才会调用大语言模型进行回复。如设置为 /chat,那么只有当消息以 /chat 开头时,才会调用大语言模型进行回复。",
value=config['llm_wake_prefix'],
path="llm_wake_prefix",
)
]
)
openai_official_llm_group = DashBoardConfig(
config_type="group",
name="OpenAI 官方接口类设置",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="list",
name="OpenAI API Key",
description="OpenAI API 的 Key。支持使用非官方但兼容的 API(第三方中转key)。",
value=config['openai']['key'],
path="openai.key",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI API 节点地址(api base)",
description="OpenAI API 的节点地址,配合非官方 API 使用。如果不想填写,那么请填写 none",
value=config['openai']['api_base'],
path="openai.api_base",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI model",
description="OpenAI LLM 模型。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['model'],
path="openai.chatGPTConfigs.model",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="OpenAI max_tokens",
description="OpenAI 最大生成长度。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['max_tokens'],
path="openai.chatGPTConfigs.max_tokens",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI temperature",
description="OpenAI 温度。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['temperature'],
path="openai.chatGPTConfigs.temperature",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI top_p",
description="OpenAI top_p。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['top_p'],
path="openai.chatGPTConfigs.top_p",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI frequency_penalty",
description="OpenAI frequency_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['frequency_penalty'],
path="openai.chatGPTConfigs.frequency_penalty",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI presence_penalty",
description="OpenAI presence_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['presence_penalty'],
path="openai.chatGPTConfigs.presence_penalty",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="OpenAI 总生成长度限制",
description="OpenAI 总生成长度限制。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['total_tokens_limit'],
path="openai.total_tokens_limit",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成模型",
description="OpenAI 图像生成模型。",
value=config['openai_image_generate']['model'],
path="openai_image_generate.model",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成大小",
description="OpenAI 图像生成大小。",
value=config['openai_image_generate']['size'],
path="openai_image_generate.size",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成风格",
description="OpenAI 图像生成风格。修改前请参考 OpenAI 官方文档",
value=config['openai_image_generate']['style'],
path="openai_image_generate.style",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成质量",
description="OpenAI 图像生成质量。修改前请参考 OpenAI 官方文档",
value=config['openai_image_generate']['quality'],
path="openai_image_generate.quality",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="问题题首提示词",
description="如果填写了此项,在每个对大语言模型的请求中,都会在问题前加上此提示词。",
value=config['llm_env_prompt'],
path="llm_env_prompt",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="默认人格文本",
description="默认人格文本",
value=config['default_personality_str'],
path="default_personality_str",
),
]
)
baidu_aip_group = DashBoardConfig(
config_type="group",
name="百度内容审核",
description="需要去申请",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启动百度内容审核服务",
description="",
value=config['baidu_aip']['enable'],
path="baidu_aip.enable"
),
DashBoardConfig(
config_type="item",
val_type="str",
name="APP ID",
description="",
value=config['baidu_aip']['app_id'],
path="baidu_aip.app_id"
),
DashBoardConfig(
config_type="item",
val_type="str",
name="API KEY",
description="",
value=config['baidu_aip']['api_key'],
path="baidu_aip.api_key"
),
DashBoardConfig(
config_type="item",
val_type="str",
name="SECRET KEY",
description="",
value=config['baidu_aip']['secret_key'],
path="baidu_aip.secret_key"
)
]
)
other_group = DashBoardConfig(
config_type="group",
name="其他配置",
description="其他配置描述",
body=[
DashBoardConfig(
config_type="item",
val_type="str",
name="HTTP 代理地址",
description="建议上下一致",
value=config['http_proxy'],
path="http_proxy",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="HTTPS 代理地址",
description="建议上下一致",
value=config['https_proxy'],
path="https_proxy",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="面板用户名",
description="是的,就是你理解的这个面板的用户名",
value=config['dashboard_username'],
path="dashboard_username",
),
]
)
dashboard_data.configs['data'] = [
qq_official_platform_group,
qq_gocq_platform_group,
general_platform_detail_group,
openai_official_llm_group,
other_group,
baidu_aip_group,
qq_aiocqhttp_platform_group
]
except Exception as e:
logger.error(f"配置文件解析错误:{e}")
raise e
def save_config(self, post_config: list, namespace: str):
'''
根据 path 解析并保存配置
'''
queue = post_config
while len(queue) > 0:
config = queue.pop(0)
if config['config_type'] == "group":
for item in config['body']:
queue.append(item)
elif config['config_type'] == "item":
if config['path'] is None or config['path'] == "":
def validate_config(self, data):
errors = []
# 递归验证数据
def validate(data, path=""):
for key, meta in CONFIG_METADATA_2.items():
if key not in data:
if key not in self.config_key_dont_show:
# 这些key不会传给前端,所以不需要验证
errors.append(f"Missing key: {path}{key}")
continue
value = data[key]
if meta["type"] == "int" and not isinstance(value, int):
errors.append(f"Invalid type for {path}{key}: expected int, got {type(value).__name__}")
elif meta["type"] == "bool" and not isinstance(value, bool):
errors.append(f"Invalid type for {path}{key}: expected bool, got {type(value).__name__}")
elif meta["type"] == "string" and not isinstance(value, str):
errors.append(f"Invalid type for {path}{key}: expected string, got {type(value).__name__}")
elif meta["type"] == "list" and not isinstance(value, list):
errors.append(f"Invalid type for {path}{key}: expected list, got {type(value).__name__}")
for item in value:
validate(item, meta["items"], path=f"{path}{key}.")
elif meta["type"] == "dict" and not isinstance(value, dict):
errors.append(f"Invalid type for {path}{key}: expected dict, got {type(value).__name__}")
validate(value, meta["items"], path=f"{path}{key}.")
validate(data)
# hardcode warning
data['config_version'] = self.context.config_helper.config_version
data['dashboard'] = asdict(self.context.config_helper.dashboard)
return errors
path = config['path'].split('.')
if len(path) == 0:
continue
if config['val_type'] == "bool":
self._write_config(
namespace, config['path'], config['value'])
elif config['val_type'] == "str":
self._write_config(
namespace, config['path'], config['value'])
elif config['val_type'] == "int":
try:
self._write_config(
namespace, config['path'], int(config['value']))
except:
raise ValueError(f"配置项 {config['name']} 的值必须是整数")
elif config['val_type'] == "float":
try:
self._write_config(
namespace, config['path'], float(config['value']))
except:
raise ValueError(f"配置项 {config['name']} 的值必须是浮点数")
elif config['val_type'] == "list":
if config['value'] is None:
self._write_config(namespace, config['path'], [])
elif not isinstance(config['value'], list):
raise ValueError(f"配置项 {config['name']} 的值必须是列表")
self._write_config(
namespace, config['path'], config['value'])
else:
raise NotImplementedError(
f"未知或者未实现的配置项类型:{config['val_type']}")
def _write_config(self, namespace: str, key: str, value):
if namespace == "" or namespace.startswith("internal_"):
# 机器人自带配置,存到 config.yaml
self.context.config_helper.put_by_dot_str(key, value)
else:
update_config(namespace, key, value)
def save_astrbot_config(self, post_config: dict):
'''验证并保存配置'''
errors = self.validate_config(post_config)
if errors:
raise ValueError(f"格式校验未通过: {errors}")
self.context.config_helper.flush_config(post_config)
def save_extension_config(self, post_config: list, namespace: str):
pass
# update_config(namespace, key, value)
+74 -66
View File
@@ -19,6 +19,7 @@ from dashboard.helper import DashBoardHelper
from util.io import get_local_ip_addresses
from model.plugin.manager import PluginManager
from util.updator.astrbot_updator import AstrBotUpdator
from type.config import CONFIG_METADATA_2
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -30,7 +31,7 @@ class AstrBotDashBoard():
self.plugin_manager = plugin_manager
self.astrbot_updator = astrbot_updator
self.dashboard_data = DashBoardData()
self.dashboard_helper = DashBoardHelper(self.context, self.dashboard_data)
self.dashboard_helper = DashBoardHelper(self.context)
self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/")
logging.getLogger('werkzeug').setLevel(logging.ERROR)
@@ -68,8 +69,8 @@ class AstrBotDashBoard():
@self.dashboard_be.post("/api/authenticate")
def authenticate():
username = self.context.base_config.get("dashboard_username", "")
password = self.context.base_config.get("dashboard_password", "")
username = self.context.config_helper.dashboard.username
password = self.context.config_helper.dashboard.password
# 获得请求体
post_data = request.json
if post_data["username"] == username and post_data["password"] == password:
@@ -90,7 +91,7 @@ class AstrBotDashBoard():
@self.dashboard_be.post("/api/change_password")
def change_password():
password = self.context.base_config.get("dashboard_password", "")
password = self.context.config_helper.dashboard.password
# 获得请求体
post_data = request.json
if post_data["password"] == password:
@@ -130,40 +131,53 @@ class AstrBotDashBoard():
@self.dashboard_be.get("/api/configs")
def get_configs():
# 如果params中有namespace,则返回该namespace下的配置
# 否则返回所有配置
# namespace 为空时返回 AstrBot 配置
# 否则返回指定 namespace 的插件配置
namespace = "" if "namespace" not in request.args else request.args["namespace"]
conf = self._get_configs(namespace)
return Response(
status="success",
message="",
data=conf
).__dict__
@self.dashboard_be.get("/api/config_outline")
def get_config_outline():
outline = self._generate_outline()
return Response(
status="success",
message="",
data=outline
).__dict__
@self.dashboard_be.post("/api/configs")
def post_configs():
post_configs = request.json
try:
self.on_post_configs(post_configs)
if not namespace:
return Response(
status="success",
message="保存成功~ 机器人将在 2 秒内重启以应用新的配置。",
message="",
data=self._get_astrbot_config()
).__dict__
return Response(
status="success",
message="",
data=self._get_extension_config(namespace)
).__dict__
@self.dashboard_be.post("/api/astrbot-configs")
def post_astrbot_configs():
post_configs = request.json
try:
self.save_astrbot_configs(post_configs)
return Response(
status="success",
message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。",
data=None
).__dict__
except Exception as e:
return Response(
status="error",
message=e.__str__(),
data=self.dashboard_data.configs
data=None
).__dict__
@self.dashboard_be.post("/api/extension-configs")
def post_extension_configs():
post_configs = request.json
try:
self.save_extension_configs(post_configs)
return Response(
status="success",
message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。",
data=None
).__dict__
except Exception as e:
return Response(
status="error",
message=e.__str__(),
data=None
).__dict__
@self.dashboard_be.get("/api/extensions")
@@ -363,47 +377,41 @@ class AstrBotDashBoard():
data=None
).__dict__
def on_post_configs(self, post_configs: dict):
def save_astrbot_configs(self, post_configs: dict):
try:
if 'base_config' in post_configs:
self.dashboard_helper.save_config(
post_configs['base_config'], namespace='') # 基础配置
self.dashboard_helper.save_config(
post_configs['config'], namespace=post_configs['namespace']) # 选定配置
self.dashboard_helper.parse_default_config(
self.dashboard_data, self.context.config_helper.get_all())
# 重启
threading.Thread(target=self.astrbot_updator._reboot,
args=(2, ), daemon=True).start()
self.dashboard_helper.save_astrbot_config(post_configs)
threading.Thread(target=self.astrbot_updator._reboot, args=(3, ), daemon=True).start()
except Exception as e:
raise e
def save_extension_configs(self, post_configs: dict):
try:
self.dashboard_helper.save_extension_config(post_configs)
threading.Thread(target=self.astrbot_updator._reboot, args=(3, ), daemon=True).start()
except Exception as e:
raise e
def _get_astrbot_config(self):
config = self.context.config_helper.to_dict()
for key in self.dashboard_helper.config_key_dont_show:
if key in config:
del config[key]
return {
"metadata": CONFIG_METADATA_2,
"config": config,
}
def _get_configs(self, namespace: str):
if namespace == "":
ret = [self.dashboard_data.configs['data'][4],
self.dashboard_data.configs['data'][5],]
elif namespace == "internal_platform_qq_official":
ret = [self.dashboard_data.configs['data'][0],]
elif namespace == "internal_platform_qq_gocq":
ret = [self.dashboard_data.configs['data'][1],]
elif namespace == "internal_platform_general": # 全局平台配置
ret = [self.dashboard_data.configs['data'][2],]
elif namespace == "internal_llm_openai_official":
ret = [self.dashboard_data.configs['data'][3],]
elif namespace == "internal_platform_qq_aiocqhttp":
ret = [self.dashboard_data.configs['data'][6],]
else:
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
ret = [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
return ret
def _get_extension_config(self, namespace: str):
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
return [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
def _generate_outline(self):
'''
+5 -6
View File
@@ -62,12 +62,11 @@ class InternalCommandHandler:
return CommandResult().message("你没有权限使用该指令。")
l = message_str.split(" ")
if len(l) == 1:
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.nick}")
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.config_helper.wake_prefix}")
nick = l[1].strip()
if not nick:
return CommandResult().message("wake: 请指定唤醒词。")
context.config_helper.put("nick_qq", nick)
context.nick = tuple(nick)
context.config_helper.wake_prefix = [nick]
return CommandResult(
hit=True,
success=True,
@@ -232,17 +231,17 @@ class InternalCommandHandler:
)
def t2i_toggle(self, message: AstrMessageEvent, context: Context):
p = context.t2i_mode
p = context.config_helper.t2i
if p:
context.config_helper.put("qq_pic_mode", False)
context.t2i_mode = False
context.config_helper.t2i = False
return CommandResult(
hit=True,
success=True,
message_chain="已关闭文本转图片模式。",
)
context.config_helper.put("qq_pic_mode", True)
context.t2i_mode = True
context.config_helper.t2i = True
return CommandResult(
hit=True,
+5 -4
View File
@@ -52,10 +52,11 @@ class Platform():
return ret[:100] if len(ret) > 100 else ret
def check_nick(self, message_str: str) -> bool:
if self.context.nick:
for nick in self.context.nick:
if nick and message_str.strip().startswith(nick):
return True
w = self.context.config_helper.wake_prefix
if not w: return False
for nick in w:
if nick and message_str.strip().startswith(nick):
return True
return False
async def convert_to_t2i_chain(self, message_result: list) -> list:
+35 -25
View File
@@ -6,6 +6,12 @@ from type.types import Context
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import (
PlatformConfig,
AiocqhttpPlatformConfig,
NakuruPlatformConfig,
QQOfficialPlatformConfig
)
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -13,36 +19,40 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class PlatformManager():
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
self.context = context
self.config = context.base_config
self.msg_handler = message_handler
def load_platforms(self):
tasks = []
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
logger.info("启用 QQ(nakuru 适配器)")
tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter"))
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
logger.info("启用 QQ(aiocqhttp 适配器)")
tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter"))
platforms = self.context.config_helper.platform
logger.info(f"加载 {len(platforms)} 个机器人消息平台...")
for platform in platforms:
if not platform.enable:
continue
if platform.name == "qq_official":
assert isinstance(platform, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
logger.info(f"加载 QQ官方 机器人消息平台 (appid: {platform.appid})")
tasks.append(asyncio.create_task(self.qqofficial_bot(platform), name="qqofficial-adapter"))
elif platform.name == "nakuru":
assert isinstance(platform, NakuruPlatformConfig), "nakuru: 无法识别的配置类型。"
logger.info(f"加载 QQ(nakuru) 机器人消息平台 ({platform.host}, {platform.websocket_port}, {platform.port})")
tasks.append(asyncio.create_task(self.nakuru_bot(platform), name="nakuru-adapter"))
elif platform.name == "aiocqhttp":
assert isinstance(platform, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
logger.info("加载 QQ(aiocqhttp) 机器人消息平台")
tasks.append(asyncio.create_task(self.aiocq_bot(platform), name="aiocqhttp-adapter"))
# QQ频道
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
logger.info("启用 QQ(官方 API) 机器人消息平台")
tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter"))
return tasks
async def gocq_bot(self):
async def nakuru_bot(self, config: NakuruPlatformConfig):
'''
运行 QQ(nakuru 适配器)
'''
from model.platform.qq_nakuru import QQGOCQ
from model.platform.qq_nakuru import QQNakuru
noticed = False
host = self.config.get("gocq_host", "127.0.0.1")
port = self.config.get("gocq_websocket_port", 6700)
http_port = self.config.get("gocq_http_port", 5700)
host = config.host
port = config.websocket_port
http_port = config.port
logger.info(
f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}")
while True:
@@ -56,32 +66,32 @@ class PlatformManager():
logger.info("nakuru 适配器已连接。")
break
try:
qq_gocq = QQGOCQ(self.context, self.msg_handler)
qq_gocq = QQNakuru(self.context, self.msg_handler, config)
self.context.platforms.append(RegisteredPlatform(
platform_name="gocq", platform_instance=qq_gocq, origin="internal"))
platform_name="nakuru", platform_instance=qq_gocq, origin="internal"))
await qq_gocq.run()
except BaseException as e:
logger.error("启动 nakuru 适配器时出现错误: " + str(e))
def aiocq_bot(self):
def aiocq_bot(self, config):
'''
运行 QQ(aiocqhttp 适配器)
'''
from model.platform.qq_aiocqhttp import AIOCQHTTP
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler)
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler, config)
self.context.platforms.append(RegisteredPlatform(
platform_name="aiocqhttp", platform_instance=qq_aiocqhttp, origin="internal"))
return qq_aiocqhttp.run_aiocqhttp()
def qqchan_bot(self):
def qqofficial_bot(self, config):
'''
运行 QQ 官方机器人适配器
'''
try:
from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(self.context, self.msg_handler)
qqchannel_bot = QQOfficial(self.context, self.msg_handler, config)
self.context.platforms.append(RegisteredPlatform(
platform_name="qqchan", platform_instance=qqchannel_bot, origin="internal"))
platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal"))
return qqchannel_bot.run()
except BaseException as e:
logger.error("启动 QQ官方机器人适配器时出现错误: " + str(e))
+15 -9
View File
@@ -13,18 +13,25 @@ from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig
logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AIOCQHTTP(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
def __init__(self, context: Context,
message_handler: MessageHandler,
platform_config: PlatformConfig) -> None:
assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
self.message_handler = message_handler
self.waiting = {}
self.context = context
self.unique_session = self.context.unique_session
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
self.host = self.context.base_config['aiocqhttp']['ws_reverse_host']
self.port = self.context.base_config['aiocqhttp']['ws_reverse_port']
self.config = platform_config
self.unique_session = context.config_helper.platform_settings.unique_session
self.announcement = context.config_helper.platform_settings.welcome_message_when_join
self.host = platform_config.ws_reverse_host
self.port = platform_config.ws_reverse_port
self.admins = context.config_helper.admins_id
def convert_message(self, event: Event) -> AstrBotMessage:
@@ -123,12 +130,11 @@ class AIOCQHTTP(Platform):
# 解析 role
sender_id = str(message.sender.user_id)
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
# construct astrbot message event
ame = AstrMessageEvent.from_astrbot_message(message, self.context, "aiocqhttp", message.session_id, role)
@@ -160,7 +166,7 @@ class AIOCQHTTP(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
if self.context.config_helper.t2i and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
+22 -16
View File
@@ -18,6 +18,7 @@ from type.command import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, NakuruPlatformConfig
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -28,46 +29,52 @@ class FakeSource:
self.group_id = group_id
class QQGOCQ(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
class QQNakuru(Platform):
def __init__(self, context: Context,
message_handler: MessageHandler,
platform_config: PlatformConfig) -> None:
assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。"
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.message_handler = message_handler
self.waiting = {}
self.context = context
self.unique_session = self.context.unique_session
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
self.unique_session = context.config_helper.platform_settings.unique_session
self.announcement = context.config_helper.platform_settings.welcome_message_when_join
self.config = platform_config
self.admins = context.config_helper.admins_id
self.client = CQHTTP(
host=self.context.base_config.get("gocq_host", "127.0.0.1"),
port=self.context.base_config.get("gocq_websocket_port", 6700),
http_port=self.context.base_config.get("gocq_http_port", 5700),
host=self.config.host,
port=self.config.websocket_port,
http_port=self.config.port
)
gocq_app = self.client
@gocq_app.receiver("GroupMessage")
async def _(app: CQHTTP, source: GroupMessage):
if self.context.base_config.get("gocq_react_group", True):
if self.config.enable_group:
abm = self.convert_message(source)
await self.handle_msg(abm)
@gocq_app.receiver("FriendMessage")
async def _(app: CQHTTP, source: FriendMessage):
if self.context.base_config.get("gocq_react_friend", True):
if self.config.enable_direct_message:
abm = self.convert_message(source)
await self.handle_msg(abm)
@gocq_app.receiver("GroupMemberIncrease")
async def _(app: CQHTTP, source: GroupMemberIncrease):
if self.context.base_config.get("gocq_react_group_increase", True):
if self.config.enable_group_increase:
await app.sendGroupMessage(source.group_id, [
Plain(text=self.announcement)
])
@gocq_app.receiver("GuildMessage")
async def _(app: CQHTTP, source: GuildMessage):
if self.cc.get("gocq_react_guild", True):
if self.config.enable_guild:
abm = self.convert_message(source)
await self.handle_msg(abm)
@@ -112,8 +119,7 @@ class QQGOCQ(Platform):
# 解析 role
sender_id = str(message.raw_message.user_id)
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
@@ -152,7 +158,7 @@ class QQGOCQ(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if self.context.base_config.get("qq_pic_mode", False) and isinstance(res, list):
if self.context.config_helper.t2i and isinstance(res, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
@@ -186,7 +192,7 @@ class QQGOCQ(Platform):
plain_text_len += len(i.text)
elif isinstance(i, Image):
image_num += 1
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
if plain_text_len > self.context.config_helper.platform_settings.forward_threshold or image_num > 1:
# 删除At
for i in message_chain:
if isinstance(i, At):
+22 -31
View File
@@ -19,6 +19,7 @@ from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -52,32 +53,36 @@ class botClient(Client):
class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler, test_mode = False) -> None:
super().__init__()
def __init__(self, context: Context,
message_handler: MessageHandler,
platform_config: PlatformConfig,
test_mode = False) -> None:
assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.message_handler = message_handler
self.waiting: dict = {}
self.context = context
self.config = platform_config
self.admins = context.config_helper.admins_id
self.appid = context.base_config['qqbot']['appid']
self.token = context.base_config['qqbot']['token']
self.secret = context.base_config['qqbot_secret']
self.unique_session = context.unique_session
qq_group = context.base_config['qqofficial_enable_group_message']
self.appid = platform_config.appid
self.secret = platform_config.secret
self.unique_session = context.config_helper.platform_settings.unique_session
qq_group = platform_config.enable_group_c2c
guild_dm = platform_config.enable_guild_direct_message
if qq_group:
self.intents = botpy.Intents(
public_messages=True,
public_guild_messages=True,
direct_message=context.base_config['direct_message_mode']
direct_message=guild_dm
)
else:
self.intents = botpy.Intents(
public_guild_messages=True,
direct_message=context.base_config['direct_message_mode']
direct_message=guild_dm
)
self.client = botClient(
intents=self.intents,
@@ -168,24 +173,11 @@ class QQOfficial(Platform):
return abm
def run(self):
try:
return self.client.start(
appid=self.appid,
secret=self.secret
)
except BaseException as e:
# 早期的 qq-botpy 版本使用 token 登录。
logger.error(traceback.format_exc())
self.client = botClient(
intents=self.intents,
bot_log=False
)
self.client.set_platform(self)
return self.client.start(
appid=self.appid,
token=self.token
)
return self.client.start(
appid=self.appid,
secret=self.secret
)
async def handle_msg(self, message: AstrBotMessage):
assert isinstance(message.raw_message, (botpy.message.Message,
botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
@@ -209,8 +201,7 @@ class QQOfficial(Platform):
# 解析出 role
sender_id = message.sender.user_id
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
sender_id in self.context.base_config.get('other_admins', None):
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
@@ -249,7 +240,7 @@ class QQOfficial(Platform):
msg_ref = None
rendered_images = []
if self.context.base_config.get("qq_pic_mode", False) and isinstance(result_message, list):
if self.context.config_helper.t2i and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(result_message)
if isinstance(result_message, list):
+22 -29
View File
@@ -1,5 +1,3 @@
import os
import sys
import json
import time
import tiktoken
@@ -15,12 +13,12 @@ from openai._exceptions import *
from astrbot.persist.helper import dbConn
from model.provider.provider import Provider
from util import general_utils as gu
from util.cmd_config import CmdConfig
from util.cmd_config import LLMConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
from type.types import Context
from dataclasses import asdict
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -48,22 +46,16 @@ MODELS = {
}
class ProviderOpenAIOfficial(Provider):
def __init__(self, context: Context) -> None:
def __init__(self, llm_config: LLMConfig) -> None:
super().__init__()
os.makedirs("data/openai", exist_ok=True)
self.context = context
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.llm_config = llm_config
self.keys_data = {} # 记录超额
cfg = context.base_config['openai']
if cfg['key']: self.api_keys = cfg['key']
if cfg['api_base']: self.base_url = cfg['api_base']
if llm_config.key: self.api_keys = llm_config.key
if llm_config.api_base: self.base_url = llm_config.api_base
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
else:
@@ -76,18 +68,20 @@ class ProviderOpenAIOfficial(Provider):
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.model_configs: Dict = cfg['chatGPTConfigs']
super().set_curr_model(self.model_configs['model'])
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
super().set_curr_model(llm_config.model_config.model)
self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小
logger.info("正在载入分词器 cl100k_base...")
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
logger.info("分词器载入完成。")
self.DEFAULT_PERSONALITY = context.default_personality
self.DEFAULT_PERSONALITY = {
"prompt": self.llm_config.default_personality,
"name": "default"
}
self.curr_personality = self.DEFAULT_PERSONALITY
self.session_personality = {} # 记录了某个session是否已设置人格。
# 从 SQLite DB 读取历史记录
@@ -134,7 +128,7 @@ class ProviderOpenAIOfficial(Provider):
self.session_personality = {} # 重置
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
tokens_num = len(encoded_prompt)
model = self.model_configs['model']
model = self.get_curr_model()
if model in MODELS and tokens_num > MODELS[model] - 500:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
@@ -173,7 +167,7 @@ class ProviderOpenAIOfficial(Provider):
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
logger.warn(f"由于当前模型 {self.get_curr_model()} 不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
continue
context.append(record['user'])
if "AI" in record and record['AI']:
@@ -185,7 +179,7 @@ class ProviderOpenAIOfficial(Provider):
'''
是否是 LVM
'''
return self.model_configs['model'].startswith("gpt-4")
return self.get_curr_model().startswith("gpt-4")
async def get_models(self):
try:
@@ -238,7 +232,7 @@ class ProviderOpenAIOfficial(Provider):
self.session_memory[session_id].append(message)
# 根据 模型的上下文窗口 淘汰掉多余的记录
curr_model = self.model_configs['model']
curr_model = self.get_curr_model()
if curr_model in MODELS:
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
# if message['usage_tokens'] > maxium_tokens_num:
@@ -314,7 +308,7 @@ class ProviderOpenAIOfficial(Provider):
# 1. 可以保证之后 pop 的时候不会出现问题
# 2. 可以保证不会超过最大 token 数
_encoded_prompt = self.tokenizer.encode(prompt)
curr_model = self.model_configs['model']
curr_model = self.get_curr_model()
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
prompt = self.tokenizer.decode(_encoded_prompt)
@@ -325,7 +319,7 @@ class ProviderOpenAIOfficial(Provider):
# 获取上下文,openai 格式
contexts = await self.retrieve_context(session_id)
conf = self.model_configs
conf = asdict(self.llm_config.model_config)
if extra_conf: conf.update(extra_conf)
# start request
@@ -358,7 +352,7 @@ class ProviderOpenAIOfficial(Provider):
except BadRequestError as e:
logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
raise Exception(f"当前模型 { self.get_curr_model() } 不支持图片输入,请更换模型。")
retry += 1
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
@@ -493,12 +487,11 @@ class ProviderOpenAIOfficial(Provider):
return contexts_str, len(self.session_memory[session_id])
def set_model(self, model: str):
self.model_configs['model'] = model
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
# TODO: 更新配置文件
super().set_curr_model(model)
def get_configs(self):
return self.model_configs
return asdict(self.llm_config)
def get_keys_data(self):
return self.keys_data
+277
View File
@@ -72,4 +72,281 @@ DEFAULT_CONFIG = {
"ws_reverse_host": "",
"ws_reverse_port": 0,
}
}
# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D
DEFAULT_CONFIG_VERSION_2 = {
"config_version": 2,
"platform": [
{
"name": "qq_official",
"enable": False,
"appid": "",
"secret": "",
"enable_group_c2c": True,
"enable_guild_direct_message": True,
},
{
"name": "nakuru",
"enable": False,
"host": "172.0.0.1",
"port": 5700,
"websocket_port": 6700,
"enable_group": True,
"enable_guild": True,
"enable_direct_message": True,
"enable_group_increase": True,
},
{
"name": "aiocqhttp",
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 6199,
}
],
"platform_settings": {
"unique_session": False,
"welcome_message_when_join": "",
"rate_limit": {
"time": 60,
"count": 30,
},
"reply_prefix": "",
"forward_threshold": 200, # 转发消息的阈值
},
"llm": [
{
"name": "openai",
"enable": True,
"key": [],
"api_base": "",
"prompt_prefix": "",
"default_personality": "",
"model_config": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"image_generation_model_config": {
"enable": True,
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
}
},
],
"llm_settings": {
"wake_prefix": "",
"web_search": False,
},
"content_safety": {
"baidu_aip": {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": "",
},
"internal_keywords": {
"enable": True,
"extra_keywords": [],
}
},
"wake_prefix": [],
"t2i": True,
"dump_history_interval": 10,
"admins_id": [],
"https_proxy": "",
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "",
"password": "",
},
}
# 这个是用于迁移旧版本配置文件的映射表
MAPPINGS_1_2 = [
[["qqbot", "enable"], ["platform", 0, "enable"]],
[["qqbot", "appid"], ["platform", 0, "appid"]],
[["qqbot", "token"], ["platform", 0, "secret"]],
[["qqofficial_enable_group_message"], ["platform", 0, "enable_group_c2c"]],
[["direct_message_mode"], ["platform", 0, "enable_guild_direct_message"]],
[["gocqbot", "enable"], ["platform", 1, "enable"]],
[["gocq_host"], ["platform", 1, "host"]],
[["gocq_http_port"], ["platform", 1, "port"]],
[["gocq_websocket_port"], ["platform", 1, "websocket_port"]],
[["gocq_react_group"], ["platform", 1, "enable_group"]],
[["gocq_react_guild"], ["platform", 1, "enable_guild"]],
[["gocq_react_friend"], ["platform", 1, "enable_direct_message"]],
[["gocq_react_group_increase"], ["platform", 1, "enable_group_increase"]],
[["aiocqhttp", "enable"], ["platform", 2, "enable"]],
[["aiocqhttp", "ws_reverse_host"], ["platform", 2, "ws_reverse_host"]],
[["aiocqhttp", "ws_reverse_port"], ["platform", 2, "ws_reverse_port"]],
[["uniqueSessionMode"], ["platform_settings", "unique_session"]],
[["qq_welcome"], ["platform_settings", "welcome_message_when_join"]],
[["limit", "time"], ["platform_settings", "rate_limit", "time"]],
[["limit", "count"], ["platform_settings", "rate_limit", "count"]],
[["reply_prefix"], ["platform_settings", "reply_prefix"]],
[["qq_forward_threshold"], ["platform_settings", "forward_threshold"]],
[["openai", "key"], ["llm", 0, "key"]],
[["openai", "api_base"], ["llm", 0, "api_base"]],
[["openai", "chatGPTConfigs", "model"], ["llm", 0, "model_config", "model"]],
[["openai", "chatGPTConfigs", "max_tokens"], ["llm", 0, "model_config", "max_tokens"]],
[["openai", "chatGPTConfigs", "temperature"], ["llm", 0, "model_config", "temperature"]],
[["openai", "chatGPTConfigs", "top_p"], ["llm", 0, "model_config", "top_p"]],
[["openai", "chatGPTConfigs", "frequency_penalty"], ["llm", 0, "model_config", "frequency_penalty"]],
[["openai", "chatGPTConfigs", "presence_penalty"], ["llm", 0, "model_config", "presence_penalty"]],
[["default_personality_str"], ["llm", 0, "default_personality"]],
[["llm_env_prompt"], ["llm", 0, "prompt_prefix"]],
[["openai_image_generate", "model"], ["llm", 0, "image_generation_model_config", "model"]],
[["openai_image_generate", "size"], ["llm", 0, "image_generation_model_config", "size"]],
[["openai_image_generate", "style"], ["llm", 0, "image_generation_model_config", "style"]],
[["openai_image_generate", "quality"], ["llm", 0, "image_generation_model_config", "quality"]],
[["llm_wake_prefix"], ["llm_settings", "wake_prefix"]],
[["baidu_aip", "enable"], ["content_safety", "baidu_aip", "enable"]],
[["baidu_aip", "app_id"], ["content_safety", "baidu_aip", "app_id"]],
[["baidu_aip", "api_key"], ["content_safety", "baidu_aip", "api_key"]],
[["baidu_aip", "secret_key"], ["content_safety", "baidu_aip", "secret_key"]],
[["qq_pic_mode"], ["t2i"]],
[["dump_history_interval"], ["dump_history_interval"]],
[["other_admins"], ["admins_id"]],
[["http_proxy"], ["http_proxy"]],
[["https_proxy"], ["https_proxy"]],
[["dashboard_username"], ["dashboard", "username"]],
[["dashboard_password"], ["dashboard", "password"]],
[["nick_qq"], ["wake_prefix"]],
]
CONFIG_METADATA_2 = {
"config_version": {"description": "配置版本", "type": "int"},
"platform": {
"description": "平台配置",
"type": "list",
"items": {
"name": {"description": "平台名称", "type": "string"},
"enable": {"description": "是否启用", "type": "bool"},
"appid": {"description": "应用ID", "type": "string"},
"secret": {"description": "应用密钥", "type": "string"},
"enable_group_c2c": {"description": "启用群C2C", "type": "bool"},
"enable_guild_direct_message": {"description": "启用公会直接消息", "type": "bool"},
"host": {"description": "主机地址", "type": "string"},
"port": {"description": "端口", "type": "int"},
"websocket_port": {"description": "WebSocket端口", "type": "int"},
"ws_reverse_host": {"description": "WebSocket反向主机", "type": "string"},
"ws_reverse_port": {"description": "WebSocket反向端口", "type": "int"},
"enable_group": {"description": "启用群组", "type": "bool"},
"enable_guild": {"description": "启用公会", "type": "bool"},
"enable_direct_message": {"description": "启用直接消息", "type": "bool"},
"enable_group_increase": {"description": "启用群组增加", "type": "bool"},
}
},
"platform_settings": {
"description": "平台设置",
"type": "object",
"items": {
"unique_session": {"description": "唯一会话", "type": "bool"},
"welcome_message_when_join": {"description": "加入时欢迎信息", "type": "string"},
"rate_limit": {
"description": "速率限制",
"type": "object",
"items": {
"time": {"description": "时间", "type": "int"},
"count": {"description": "计数", "type": "int"},
}
},
"reply_prefix": {"description": "回复前缀", "type": "string"},
"forward_threshold": {"description": "转发消息的阈值", "type": "int"},
}
},
"llm": {
"description": "大语言模型配置",
"type": "list",
"items": {
"name": {"description": "模型名称", "type": "string"},
"enable": {"description": "是否启用", "type": "bool"},
"key": {"description": "密钥", "type": "list", "items": {"type": "string"}},
"api_base": {"description": "API基础URL", "type": "string"},
"prompt_prefix": {"description": "提示前缀", "type": "string"},
"default_personality": {"description": "默认个性", "type": "string"},
"model_config": {
"description": "模型配置",
"type": "object",
"items": {
"model": {"description": "模型名称", "type": "string"},
"max_tokens": {"description": "最大令牌数", "type": "int"},
"temperature": {"description": "温度", "type": "float"},
"top_p": {"description": "Top P值", "type": "float"},
"frequency_penalty": {"description": "频率惩罚", "type": "float"},
"presence_penalty": {"description": "存在惩罚", "type": "float"},
}
},
"image_generation_model_config": {
"description": "图像生成模型配置",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"model": {"description": "模型名称", "type": "string"},
"size": {"description": "图像尺寸", "type": "string"},
"style": {"description": "图像风格", "type": "string"},
"quality": {"description": "图像质量", "type": "string"},
}
},
}
},
"llm_settings": {
"description": "大语言模型设置",
"type": "object",
"items": {
"wake_prefix": {"description": "唤醒前缀", "type": "string"},
"web_search": {"description": "启用网络搜索", "type": "bool"},
}
},
"content_safety": {
"description": "内容安全配置",
"type": "object",
"items": {
"baidu_aip": {
"description": "百度AI平台配置",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"app_id": {"description": "应用ID", "type": "string"},
"api_key": {"description": "API密钥", "type": "string"},
"secret_key": {"description": "秘密密钥", "type": "string"},
}
},
"internal_keywords": {
"description": "内部关键词过滤",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"extra_keywords": {"description": "额外关键词", "type": "list", "items": {"type": "string"}},
}
}
}
},
"wake_prefix": {"description": "唤醒前缀列表", "type": "list", "items": {"type": "string"}},
"t2i": {"description": "文本转图像功能", "type": "bool"},
"dump_history_interval": {"description": "历史记录转储间隔", "type": "int"},
"admins_id": {"description": "管理员ID列表", "type": "list", "items": {"type": "int"}},
"https_proxy": {"description": "HTTPS代理", "type": "string"},
"http_proxy": {"description": "HTTP代理", "type": "string"},
"dashboard": {
"description": "仪表盘配置",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"username": {"description": "用户名", "type": "string"},
"password": {"description": "密码", "type": "string"},
}
},
}
+8 -8
View File
@@ -3,7 +3,7 @@ from asyncio import Task
from type.register import *
from typing import List, Awaitable
from logging import Logger
from util.cmd_config import CmdConfig
from util.cmd_config import AstrBotConfig
from util.t2i.renderer import TextToImageRenderer
from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
@@ -20,17 +20,17 @@ class Context:
def __init__(self):
self.logger: Logger = None
self.base_config: dict = None # 配置(期望启动机器人后是不变的)
self.config_helper: CmdConfig = None
self.config_helper: AstrBotConfig = None
self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件
self.platforms: List[RegisteredPlatform] = []
self.llms: List[RegisteredLLM] = []
self.default_personality: dict = None
self.unique_session = False # 独立会话
self.version: str = None # 机器人版本
self.nick: tuple = None # gocq 的唤醒词
self.t2i_mode = False
self.web_search = False # 是否开启了网页搜索
# self.unique_session = False # 独立会话
# self.version: str = None # 机器人版本
# self.nick: tuple = None # gocq 的唤醒词
# self.t2i_mode = False
# self.web_search = False # 是否开启了网页搜索
self.metrics_uploader = None
self.updator: AstrBotUpdator = None
@@ -42,7 +42,7 @@ class Context:
self.ext_tasks: List[Task] = []
# useless
self.reply_prefix = ""
# self.reply_prefix = ""
def register_commands(self,
plugin_name: str,
+242 -45
View File
@@ -1,33 +1,236 @@
import os
import json
from type.config import DEFAULT_CONFIG
import logging
from type.config import DEFAULT_CONFIG, DEFAULT_CONFIG_VERSION_2, MAPPINGS_1_2
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional
cpath = "data/cmd_config.json"
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
logger = logging.getLogger("astrbot")
def check_exist():
if not os.path.exists(cpath):
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump({}, f, ensure_ascii=False)
f.flush()
@dataclass
class RateLimit:
time: int = 60
count: int = 30
@dataclass
class PlatformSettings:
unique_session: bool = False
welcome_message_when_join: str = ""
rate_limit: RateLimit = field(default_factory=RateLimit)
reply_prefix: str = ""
forward_threshold: int = 200
def __post_init__(self):
self.rate_limit = RateLimit(**self.rate_limit)
@dataclass
class PlatformConfig():
name: str = ""
enable: bool = False
@dataclass
class QQOfficialPlatformConfig(PlatformConfig):
appid: str = ""
secret: str = ""
enable_group_c2c: bool = True
enable_guild_direct_message: bool = True
@dataclass
class NakuruPlatformConfig(PlatformConfig):
host: str = "172.0.0.1",
port: int = 5700,
websocket_port: int = 6700,
enable_group: bool = True,
enable_guild: bool = True,
enable_direct_message: bool = True,
enable_group_increase: bool = True
@dataclass
class AiocqhttpPlatformConfig(PlatformConfig):
ws_reverse_host: str = ""
ws_reverse_port: int = 6199
@dataclass
class ModelConfig:
model: str = "gpt-4o"
max_tokens: int = 6000
temperature: float = 0.9
top_p: float = 1
frequency_penalty: float = 0
presence_penalty: float = 0
@dataclass
class ImageGenerationModelConfig:
enable: bool = True
model: str = "dall-e-3"
size: str = "1024x1024"
style: str = "vivid"
quality: str = "standard"
@dataclass
class LLMConfig:
name: str = "openai"
enable: bool = True
key: List[str] = field(default_factory=list)
api_base: str = ""
prompt_prefix: str = ""
default_personality: str = ""
model_config: ModelConfig = field(default_factory=ModelConfig)
image_generation_model_config: ImageGenerationModelConfig = field(default_factory=ImageGenerationModelConfig)
def __post_init__(self):
self.model_config = ModelConfig(**self.model_config)
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config)
@dataclass
class LLMSettings:
wake_prefix: str = ""
web_search: bool = False
@dataclass
class BaiduAIPConfig:
enable: bool = False
app_id: str = ""
api_key: str = ""
secret_key: str = ""
@dataclass
class InternalKeywordsConfig:
enable: bool = True
extra_keywords: List[str] = field(default_factory=list)
@dataclass
class ContentSafetyConfig:
baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig)
internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig)
def __post_init__(self):
self.baidu_aip = BaiduAIPConfig(**self.baidu_aip)
self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords)
@dataclass
class DashboardConfig:
enable: bool = True
username: str = ""
password: str = ""
@dataclass
class AstrBotConfig():
config_version: int = 2
platform_settings: PlatformSettings = field(default_factory=PlatformSettings)
llm: List[LLMConfig] = field(default_factory=list)
llm_settings: LLMSettings = field(default_factory=LLMSettings)
content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig)
t2i: bool = True
dump_history_interval: int = 10
admins_id: List[str] = field(default_factory=list)
https_proxy: str = ""
http_proxy: str = ""
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
platform: List[PlatformConfig] = field(default_factory=list)
wake_prefix: List[str] = field(default_factory=list)
class CmdConfig():
def __init__(self) -> None:
self.cached_config: dict = {}
self.init_configs()
# compability
if isinstance(self.wake_prefix, str):
self.wake_prefix = [self.wake_prefix]
def load_from_dict(self, data: Dict):
'''从字典中加载配置到对象。
@note: 适用于 version 2 配置文件。
'''
self.config_version=data.get("version", 2)
self.platform=[]
for p in data.get("platform", []):
if 'name' not in p:
logger.warning("A platform config missing name, skipping.")
continue
if p["name"] == "qq_official":
self.platform.append(QQOfficialPlatformConfig(**p))
elif p["name"] == "nakuru":
self.platform.append(NakuruPlatformConfig(**p))
elif p["name"] == "aiocqhttp":
self.platform.append(AiocqhttpPlatformConfig(**p))
else:
self.platform.append(PlatformConfig(**p))
self.platform_settings=PlatformSettings(**data.get("platform_settings", {}))
self.llm=[LLMConfig(**l) for l in data.get("llm", [])]
self.llm_settings=LLMSettings(**data.get("llm_settings", {}))
self.content_safety=ContentSafetyConfig(**data.get("content_safety", {}))
self.t2i=data.get("t2i", True)
self.dump_history_interval=data.get("dump_history_interval", 10)
self.admins_id=data.get("admins_id", [])
self.https_proxy=data.get("https_proxy", "")
self.http_proxy=data.get("http_proxy", "")
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
self.wake_prefix=data.get("wake_prefix", [])
def migrate_config_1_2(self, old: dict) -> dict:
'''将配置文件从版本 1 迁移至版本 2'''
logger.info("正在更新配置文件到 version 2...")
new_config = DEFAULT_CONFIG_VERSION_2
mappings = MAPPINGS_1_2
def set_nested_value(d, keys, value):
cursor = d
for key in keys[:-1]:
cursor = cursor[key]
cursor[keys[-1]] = value
for old_path, new_path in mappings:
value = old
try:
for key in old_path:
value = value[key] # soooooo convenient!!
set_nested_value(new_config, new_path, value)
except KeyError:
# 如果旧配置中没有这个键,跳过,即使用新配置的默认值
continue
logger.info("配置文件更新完成。")
return new_config
def flush_config(self, config: dict = None):
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False)
f.flush()
def init_configs(self):
'''
初始化必需的配置项
'''
self.init_config_items(DEFAULT_CONFIG)
'''初始化必需的配置项'''
config = None
if not self.check_exist():
self.flush_config()
config = DEFAULT_CONFIG_VERSION_2
else:
config = self.get_all()
# check if the config is outdated
if 'config_version' not in config: # version 1
config = self.migrate_config_1_2(config)
self.flush_config(config)
_tag = False
for key, val in DEFAULT_CONFIG_VERSION_2.items():
if key not in config:
config[key] = val
_tag = True
if _tag:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
f.flush()
@staticmethod
def get(key, default=None):
self.load_from_dict(config)
def get(self, key, default=None):
'''
从文件系统中直接获取配置
'''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
if key in d:
return d[key]
@@ -38,8 +241,7 @@ class CmdConfig():
'''
从文件系统中获取所有配置
'''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
@@ -47,21 +249,19 @@ class CmdConfig():
return conf
def put(self, key, value):
with open(cpath, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
d[key] = value
with open(cpath, "w", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=2, ensure_ascii=False)
f.flush()
self.cached_config[key] = value
@staticmethod
def put_by_dot_str(key: str, value):
'''
根据点分割的字符串,将值写入配置文件
'''
with open(cpath, "r", encoding="utf-8-sig") as f:
def to_dict(self) -> Dict:
return asdict(self)
def put_by_dot_str(self, key: str, value):
'''根据点分割的字符串,将值写入配置文件'''
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
_d = d
_ks = key.split(".")
@@ -70,23 +270,20 @@ class CmdConfig():
_d[_ks[i]] = value
else:
_d = _d[_ks[i]]
with open(cpath, "w", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=2, ensure_ascii=False)
f.flush()
def init_config_items(self, d: dict):
conf = self.get_all()
def update_by_path(self, path: List):
'''根据路径更新配置文件。
if not self.cached_config:
self.cached_config = conf
这个方法首先会更新缓存在内存中的配置,然后再写入文件。
'''
for key in path:
if key not in self:
raise KeyError(f"Key {key} not found in config.")
_tag = False
for key, val in d.items():
if key not in conf:
conf[key] = val
_tag = True
if _tag:
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(conf, f, indent=2, ensure_ascii=False)
f.flush()
def check_exist(self) -> bool:
return os.path.exists(ASTRBOT_CONFIG_PATH)
+9 -10
View File
@@ -1,16 +1,15 @@
import json, os
from util.cmd_config import CmdConfig
import json, os, shutil
import logging
logger = logging.getLogger("astrbot")
def try_migrate_config():
'''
将 cmd_config.json 迁移至 data/cmd_config.json (如果存在的话)
'''
if os.path.exists("cmd_config.json"):
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
data = json.load(f)
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
if os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"):
try:
os.remove("cmd_config.json")
except Exception as e:
pass
shutil.move("cmd_config.json", "data/cmd_config.json")
except:
logger.error("迁移 cmd_config.json 失败。AstrBot 将不会读取配置文件,你可以手动将 cmd_config.json 迁移至 data/cmd_config.json。")
+2 -1
View File
@@ -5,6 +5,7 @@ import sys
from type.types import Context
from collections import defaultdict
from type.config import VERSION
class MetricUploader():
def __init__(self, context: Context) -> None:
@@ -49,7 +50,7 @@ class MetricUploader():
try:
res = {
"stat_version": "moon",
"version": context.version, # 版本号
"version": VERSION, # 版本号
"platform_stats": self.platform_stats, # 过去 30 分钟各消息平台交互消息数
"llm_stats": self.llm_stats,
"plugin_stats": self.plugin_stats,
+1 -1
View File
@@ -7,6 +7,6 @@ Platform类是消息平台的抽象类,定义了消息平台的基本接口。
from model.platform import Platform
from model.platform.qq_nakuru import QQGOCQ
from model.platform.qq_nakuru import QQNakuru
from model.platform.qq_official import QQOfficial
from model.platform.qq_aiocqhttp import AIOCQHTTP