Merge pull request #184 from Soulter/config-refactor

更易读的配置格式和平台、LLM多实例
This commit is contained in:
Soulter
2024-09-10 11:01:42 +00:00
committed by GitHub
24 changed files with 989 additions and 928 deletions
+5 -1
View File
@@ -26,7 +26,11 @@ jobs:
mkdir temp
- name: Run tests
run: PYTHONPATH=./ pytest --cov=. tests/ -v
run: |
export LLM_MODEL=${{ secrets.LLM_MODEL }}
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
+2 -1
View File
@@ -10,4 +10,5 @@ cmd_config.json
data/*
cookies.json
logs/
addons/plugins
addons/plugins
.coverage
+24 -29
View File
@@ -13,7 +13,7 @@ from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from util.cmd_config import CmdConfig
from util.cmd_config import AstrBotConfig
from util.metrics import MetricUploader
from util.config_utils import *
from util.updator.astrbot_updator import AstrBotUpdator
@@ -24,29 +24,15 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AstrBotBootstrap():
def __init__(self) -> None:
self.context = Context()
self.config_helper = CmdConfig()
# load configs and ensure the backward compatibility
try_migrate_config()
self.config_helper = AstrBotConfig()
self.context.config_helper = self.config_helper
self.context.base_config = self.config_helper.cached_config
self.context.default_personality = {
"name": "default",
"prompt": self.context.base_config.get("default_personality_str", ""),
}
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
self.context.nick = nick_qq
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
self.context.version = VERSION
logger.info("AstrBot v" + self.context.version)
logger.info("AstrBot v" + VERSION)
# apply proxy settings
http_proxy = self.context.base_config.get("http_proxy")
https_proxy = self.context.base_config.get("https_proxy")
http_proxy = self.context.config_helper.http_proxy
https_proxy = self.context.config_helper.https_proxy
if http_proxy:
os.environ['HTTP_PROXY'] = http_proxy
if https_proxy:
@@ -68,10 +54,9 @@ class AstrBotBootstrap():
self.db_conn_helper = dbConn()
# load llm provider
self.llm_instance: Provider = None
self.load_llm()
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper, self.llm_instance)
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper)
self.platfrom_manager = PlatformManager(self.context, self.message_handler)
self.dashboard = AstrBotDashBoard(self.context, plugin_manager=self.plugin_manager, astrbot_updator=self.updator)
self.metrics_uploader = MetricUploader(self.context)
@@ -114,16 +99,26 @@ class AstrBotBootstrap():
return
def load_llm(self):
if 'openai' in self.config_helper.cached_config and \
len(self.config_helper.cached_config['openai']['key']) and \
self.config_helper.cached_config['openai']['key'][0] is not None:
from model.provider.openai_official import ProviderOpenAIOfficial
f = False
llms = self.context.config_helper.llm
logger.info(f"加载 {len(llms)} 个 LLM Provider...")
for llm in llms:
if llm.enable:
if llm.name == "openai" and llm.key and llm.enable:
self.load_openai(llm)
f = True
logger.info(f"已启用 OpenAI API 支持。")
else:
logger.warn(f"未知的 LLM Provider: {llm.name}")
if f:
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
self.llm_instance = ProviderOpenAIOfficial(self.context)
self.openai_command_handler.set_provider(self.llm_instance)
self.context.register_provider("internal_openai", self.llm_instance)
logger.info("已启用 OpenAI API 支持。")
self.openai_command_handler.set_provider(self.context.llms[0].llm_instance)
def load_openai(self, llm_config):
from model.provider.openai_official import ProviderOpenAIOfficial
inst = ProviderOpenAIOfficial(llm_config)
self.context.register_provider("internal_openai", inst)
def load_plugins(self):
self.plugin_manager.plugin_reload()
+8 -9
View File
@@ -1,16 +1,15 @@
from aip import AipContentCensor
from util.cmd_config import BaiduAIPConfig
class BaiduJudge:
def __init__(self, baidu_configs) -> None:
if 'app_id' in baidu_configs and 'api_key' in baidu_configs and 'secret_key' in baidu_configs:
self.app_id = str(baidu_configs['app_id'])
self.api_key = baidu_configs['api_key']
self.secret_key = baidu_configs['secret_key']
self.client = AipContentCensor(
self.app_id, self.api_key, self.secret_key)
else:
raise ValueError("Baidu configs error! 请填写百度内容审核服务相关配置!")
def __init__(self, baidu_configs: BaiduAIPConfig) -> None:
self.app_id = baidu_configs.app_id
self.api_key = baidu_configs.api_key
self.secret_key = baidu_configs.secret_key
self.client = AipContentCensor(self.app_id,
self.api_key,
self.secret_key)
def judge(self, text):
res = self.client.textCensorUserDefined(text)
+13 -28
View File
@@ -25,16 +25,11 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class RateLimitHelper():
def __init__(self, context: Context) -> None:
self.user_rate_limit: Dict[int, int] = {}
self.rate_limit_time: int = 60
self.rate_limit_count: int = 10
rl = context.config_helper.platform_settings.rate_limit
self.rate_limit_time: int = rl.time
self.rate_limit_count: int = rl.count
self.user_frequency = {}
if 'limit' in context.base_config:
if 'count' in context.base_config['limit']:
self.rate_limit_count = context.base_config['limit']['count']
if 'time' in context.base_config['limit']:
self.rate_limit_time = context.base_config['limit']['time']
def check_frequency(self, session_id: str) -> bool:
'''
检查发言频率
@@ -59,12 +54,11 @@ class RateLimitHelper():
class ContentSafetyHelper():
def __init__(self, context: Context) -> None:
self.baidu_judge = None
if 'baidu_api' in context.base_config and \
'enable' in context.base_config['baidu_aip'] and \
context.base_config['baidu_aip']['enable']:
aip = context.config_helper.content_safety.baidu_aip
if aip.enable:
try:
from astrbot.message.baidu_aip_judge import BaiduJudge
self.baidu_judge = BaiduJudge(context.base_config['baidu_aip'])
self.baidu_judge = BaiduJudge(aip)
logger.info("已启用百度 AI 内容审核。")
except BaseException as e:
logger.error("百度 AI 内容审核初始化失败。")
@@ -107,22 +101,19 @@ class ContentSafetyHelper():
class MessageHandler():
def __init__(self, context: Context,
command_manager: CommandManager,
persist_manager: dbConn,
provider: Provider) -> None:
persist_manager: dbConn) -> None:
self.context = context
self.command_manager = command_manager
self.persist_manager = persist_manager
self.rate_limit_helper = RateLimitHelper(context)
self.content_safety_helper = ContentSafetyHelper(context)
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
self.llm_wake_prefix = self.context.config_helper.llm_settings.wake_prefix
if self.llm_wake_prefix:
self.llm_wake_prefix = self.llm_wake_prefix.strip()
self.nicks = self.context.nick
self.provider = provider
self.reply_prefix = str(self.context.reply_prefix)
self.provider = self.context.llms[0].llm_instance if len(self.context.llms) > 0 else None
self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix)
self.llm_tools = FuncCall(self.provider)
def set_provider(self, provider: Provider):
self.provider = provider
@@ -148,7 +139,7 @@ class MessageHandler():
return
# remove the nick prefix
for nick in self.nicks:
for nick in self.context.config_helper.wake_prefix:
if msg_plain.startswith(nick):
msg_plain = msg_plain.removeprefix(nick)
break
@@ -187,12 +178,6 @@ class MessageHandler():
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
break
# web_search = self.context.web_search
# if not web_search and msg_plain.startswith("ws"):
# # leverage web search feature
# web_search = True
# msg_plain = msg_plain.removeprefix("ws").strip()
try:
if not self.llm_tools.empty():
# tools-use
-1
View File
@@ -2,7 +2,6 @@ from dataclasses import dataclass
class DashBoardData():
stats: dict = {}
configs: dict = {}
@dataclass
class Response():
+57 -521
View File
@@ -1,537 +1,73 @@
import threading
import asyncio
from . import DashBoardData
from typing import Union, Optional
from util.cmd_config import CmdConfig
from dataclasses import dataclass
from util.cmd_config import AstrBotConfig
from dataclasses import dataclass, asdict
from util.plugin_dev.api.v1.config import update_config
from SparkleLogging.utils.core import LogManager
from logging import Logger
from type.types import Context
from type.config import CONFIG_METADATA_2
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@dataclass
class DashBoardConfig():
config_type: str
name: Optional[str] = None
description: Optional[str] = None
path: Optional[str] = None # 仅 item 才需要
body: Optional[list['DashBoardConfig']] = None # 仅 group 才需要
value: Optional[Union[list, dict, str, int, bool]] = None # 仅 item 才需要
val_type: Optional[str] = None # 仅 item 才需要
class DashBoardHelper():
def __init__(self, context: Context, dashboard_data: DashBoardData):
dashboard_data.configs = {
"data": []
}
def __init__(self, context: Context):
self.context = context
self.parse_default_config(dashboard_data, context.base_config)
self.config_key_dont_show = ['dashboard', 'config_version']
# 将 config.yaml、 中的配置解析到 dashboard_data.configs 中
def parse_default_config(self, dashboard_data: DashBoardData, config: dict):
try:
qq_official_platform_group = DashBoardConfig(
config_type="group",
name="QQ(官方)",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用 QQ_OFFICIAL 平台",
description="官方的接口,仅支持 QQ 频道。详见 q.qq.com",
value=config['qqbot']['enable'],
path="qqbot.enable",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="QQ机器人APPID",
description="详见 q.qq.com",
value=config['qqbot']['appid'],
path="qqbot.appid",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="QQ机器人令牌",
description="详见 q.qq.com",
value=config['qqbot']['token'],
path="qqbot.token",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="QQ机器人 Secret",
description="详见 q.qq.com",
value=config['qqbot_secret'],
path="qqbot_secret",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否允许 QQ 频道私聊",
description="如果启用,机器人会响应私聊消息。",
value=config['direct_message_mode'],
path="direct_message_mode",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否接收QQ群消息",
description="需要机器人有相应的群消息接收权限。在 q.qq.com 上查看。",
value=config['qqofficial_enable_group_message'],
path="qqofficial_enable_group_message",
),
]
)
qq_gocq_platform_group = DashBoardConfig(
config_type="group",
name="QQ(nakuru)",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用",
description="",
value=config['gocqbot']['enable'],
path="gocqbot.enable",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="HTTP 服务器地址",
description="",
value=config['gocq_host'],
path="gocq_host",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="HTTP 服务器端口",
description="",
value=config['gocq_http_port'],
path="gocq_http_port",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="WebSocket 服务器端口",
description="目前仅支持正向 WebSocket",
value=config['gocq_websocket_port'],
path="gocq_websocket_port",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应群消息",
description="",
value=config['gocq_react_group'],
path="gocq_react_group",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应私聊消息",
description="",
value=config['gocq_react_friend'],
path="gocq_react_friend",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应群成员增加消息",
description="",
value=config['gocq_react_group_increase'],
path="gocq_react_group_increase",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="是否响应频道消息",
description="",
value=config['gocq_react_guild'],
path="gocq_react_guild",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="转发阈值(字符数)",
description="机器人回复的消息长度超出这个值后,会被折叠成转发卡片发出以减少刷屏。",
value=config['qq_forward_threshold'],
path="qq_forward_threshold",
),
]
)
qq_aiocqhttp_platform_group = DashBoardConfig(
config_type="group",
name="QQ(aiocqhttp)",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启用",
description="",
value=config['aiocqhttp']['enable'],
path="aiocqhttp.enable",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="WebSocket 反向连接 host",
description="",
value=config['aiocqhttp']['ws_reverse_host'],
path="aiocqhttp.ws_reverse_host",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="WebSocket 反向连接 port",
description="",
value=config['aiocqhttp']['ws_reverse_port'],
path="aiocqhttp.ws_reverse_port",
),
]
)
general_platform_detail_group = DashBoardConfig(
config_type="group",
name="通用平台配置",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启动消息文字转图片",
description="启动后,机器人会将消息转换为图片发送,以降低风控风险。",
value=config['qq_pic_mode'],
path="qq_pic_mode",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="消息限制时间",
description="在此时间内,机器人不会回复同一个用户的消息。单位:秒",
value=config['limit']['time'],
path="limit.time",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="消息限制次数",
description="在上面的时间内,如果用户发送消息超过此次数,则机器人不会回复。单位:次",
value=config['limit']['count'],
path="limit.count",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="回复前缀",
description="[xxxx] 你好! 其中xxxx是你可以填写的前缀。如果为空则不显示。",
value=config['reply_prefix'],
path="reply_prefix",
),
DashBoardConfig(
config_type="item",
val_type="list",
name="通用管理员用户 ID(支持多个管理员)。通过 !myid 指令获取。",
description="",
value=config['other_admins'],
path="other_admins",
),
DashBoardConfig(
config_type="item",
val_type="bool",
name="独立会话",
description="是否启用独立会话模式,即 1 个用户自然账号 1 个会话。",
value=config['uniqueSessionMode'],
path="uniqueSessionMode",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="LLM 唤醒词",
description="如果不为空, 那么只有当消息以此词开头时,才会调用大语言模型进行回复。如设置为 /chat,那么只有当消息以 /chat 开头时,才会调用大语言模型进行回复。",
value=config['llm_wake_prefix'],
path="llm_wake_prefix",
)
]
)
openai_official_llm_group = DashBoardConfig(
config_type="group",
name="OpenAI 官方接口类设置",
description="",
body=[
DashBoardConfig(
config_type="item",
val_type="list",
name="OpenAI API Key",
description="OpenAI API 的 Key。支持使用非官方但兼容的 API(第三方中转key)。",
value=config['openai']['key'],
path="openai.key",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI API 节点地址(api base)",
description="OpenAI API 的节点地址,配合非官方 API 使用。如果不想填写,那么请填写 none",
value=config['openai']['api_base'],
path="openai.api_base",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI model",
description="OpenAI LLM 模型。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['model'],
path="openai.chatGPTConfigs.model",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="OpenAI max_tokens",
description="OpenAI 最大生成长度。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['max_tokens'],
path="openai.chatGPTConfigs.max_tokens",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI temperature",
description="OpenAI 温度。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['temperature'],
path="openai.chatGPTConfigs.temperature",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI top_p",
description="OpenAI top_p。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['top_p'],
path="openai.chatGPTConfigs.top_p",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI frequency_penalty",
description="OpenAI frequency_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['frequency_penalty'],
path="openai.chatGPTConfigs.frequency_penalty",
),
DashBoardConfig(
config_type="item",
val_type="float",
name="OpenAI presence_penalty",
description="OpenAI presence_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['chatGPTConfigs']['presence_penalty'],
path="openai.chatGPTConfigs.presence_penalty",
),
DashBoardConfig(
config_type="item",
val_type="int",
name="OpenAI 总生成长度限制",
description="OpenAI 总生成长度限制。详见 https://platform.openai.com/docs/api-reference/chat",
value=config['openai']['total_tokens_limit'],
path="openai.total_tokens_limit",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成模型",
description="OpenAI 图像生成模型。",
value=config['openai_image_generate']['model'],
path="openai_image_generate.model",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成大小",
description="OpenAI 图像生成大小。",
value=config['openai_image_generate']['size'],
path="openai_image_generate.size",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成风格",
description="OpenAI 图像生成风格。修改前请参考 OpenAI 官方文档",
value=config['openai_image_generate']['style'],
path="openai_image_generate.style",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="OpenAI 图像生成质量",
description="OpenAI 图像生成质量。修改前请参考 OpenAI 官方文档",
value=config['openai_image_generate']['quality'],
path="openai_image_generate.quality",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="问题题首提示词",
description="如果填写了此项,在每个对大语言模型的请求中,都会在问题前加上此提示词。",
value=config['llm_env_prompt'],
path="llm_env_prompt",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="默认人格文本",
description="默认人格文本",
value=config['default_personality_str'],
path="default_personality_str",
),
]
)
baidu_aip_group = DashBoardConfig(
config_type="group",
name="百度内容审核",
description="需要去申请",
body=[
DashBoardConfig(
config_type="item",
val_type="bool",
name="启动百度内容审核服务",
description="",
value=config['baidu_aip']['enable'],
path="baidu_aip.enable"
),
DashBoardConfig(
config_type="item",
val_type="str",
name="APP ID",
description="",
value=config['baidu_aip']['app_id'],
path="baidu_aip.app_id"
),
DashBoardConfig(
config_type="item",
val_type="str",
name="API KEY",
description="",
value=config['baidu_aip']['api_key'],
path="baidu_aip.api_key"
),
DashBoardConfig(
config_type="item",
val_type="str",
name="SECRET KEY",
description="",
value=config['baidu_aip']['secret_key'],
path="baidu_aip.secret_key"
)
]
)
other_group = DashBoardConfig(
config_type="group",
name="其他配置",
description="其他配置描述",
body=[
DashBoardConfig(
config_type="item",
val_type="str",
name="HTTP 代理地址",
description="建议上下一致",
value=config['http_proxy'],
path="http_proxy",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="HTTPS 代理地址",
description="建议上下一致",
value=config['https_proxy'],
path="https_proxy",
),
DashBoardConfig(
config_type="item",
val_type="str",
name="面板用户名",
description="是的,就是你理解的这个面板的用户名",
value=config['dashboard_username'],
path="dashboard_username",
),
]
)
dashboard_data.configs['data'] = [
qq_official_platform_group,
qq_gocq_platform_group,
general_platform_detail_group,
openai_official_llm_group,
other_group,
baidu_aip_group,
qq_aiocqhttp_platform_group
]
except Exception as e:
logger.error(f"配置文件解析错误:{e}")
raise e
def save_config(self, post_config: list, namespace: str):
'''
根据 path 解析并保存配置
'''
queue = post_config
while len(queue) > 0:
config = queue.pop(0)
if config['config_type'] == "group":
for item in config['body']:
queue.append(item)
elif config['config_type'] == "item":
if config['path'] is None or config['path'] == "":
def validate_config(self, data):
errors = []
# 递归验证数据
def validate(data, path=""):
for key, meta in CONFIG_METADATA_2.items():
if key not in data:
if key not in self.config_key_dont_show:
# 这些key不会传给前端,所以不需要验证
errors.append(f"Missing key: {path}{key}")
continue
value = data[key]
if meta["type"] == "int" and not isinstance(value, int):
errors.append(f"Invalid type for {path}{key}: expected int, got {type(value).__name__}")
elif meta["type"] == "bool" and not isinstance(value, bool):
errors.append(f"Invalid type for {path}{key}: expected bool, got {type(value).__name__}")
elif meta["type"] == "string" and not isinstance(value, str):
errors.append(f"Invalid type for {path}{key}: expected string, got {type(value).__name__}")
elif meta["type"] == "list" and not isinstance(value, list):
errors.append(f"Invalid type for {path}{key}: expected list, got {type(value).__name__}")
for item in value:
validate(item, meta["items"], path=f"{path}{key}.")
elif meta["type"] == "dict" and not isinstance(value, dict):
errors.append(f"Invalid type for {path}{key}: expected dict, got {type(value).__name__}")
validate(value, meta["items"], path=f"{path}{key}.")
validate(data)
# hardcode warning
data['config_version'] = self.context.config_helper.config_version
data['dashboard'] = asdict(self.context.config_helper.dashboard)
return errors
path = config['path'].split('.')
if len(path) == 0:
continue
def save_astrbot_config(self, post_config: dict):
'''验证并保存配置'''
errors = self.validate_config(post_config)
if errors:
raise ValueError(f"格式校验未通过: {errors}")
self.context.config_helper.flush_config(post_config)
def save_extension_config(self, post_config: dict):
if 'namespace' not in post_config:
raise ValueError("Missing key: namespace")
if 'config' not in post_config:
raise ValueError("Missing key: config")
if config['val_type'] == "bool":
self._write_config(
namespace, config['path'], config['value'])
elif config['val_type'] == "str":
self._write_config(
namespace, config['path'], config['value'])
elif config['val_type'] == "int":
try:
self._write_config(
namespace, config['path'], int(config['value']))
except:
raise ValueError(f"配置项 {config['name']} 的值必须是整数")
elif config['val_type'] == "float":
try:
self._write_config(
namespace, config['path'], float(config['value']))
except:
raise ValueError(f"配置项 {config['name']} 的值必须是浮点数")
elif config['val_type'] == "list":
if config['value'] is None:
self._write_config(namespace, config['path'], [])
elif not isinstance(config['value'], list):
raise ValueError(f"配置项 {config['name']} 的值必须是列表")
self._write_config(
namespace, config['path'], config['value'])
else:
raise NotImplementedError(
f"未知或者未实现的配置项类型:{config['val_type']}")
def _write_config(self, namespace: str, key: str, value):
if namespace == "" or namespace.startswith("internal_"):
# 机器人自带配置,存到 config.yaml
self.context.config_helper.put_by_dot_str(key, value)
else:
namespace = post_config['namespace']
config: list = post_config['config'][0]['body']
for item in config:
key = item['path']
value = item['value']
typ = item['val_type']
if typ == 'int':
if not value.isdigit():
raise ValueError(f"Invalid type for {namespace}.{key}: expected int, got {type(value).__name__}")
value = int(value)
update_config(namespace, key, value)
+75 -126
View File
@@ -19,6 +19,7 @@ from dashboard.helper import DashBoardHelper
from util.io import get_local_ip_addresses
from model.plugin.manager import PluginManager
from util.updator.astrbot_updator import AstrBotUpdator
from type.config import CONFIG_METADATA_2
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -30,9 +31,10 @@ class AstrBotDashBoard():
self.plugin_manager = plugin_manager
self.astrbot_updator = astrbot_updator
self.dashboard_data = DashBoardData()
self.dashboard_helper = DashBoardHelper(self.context, self.dashboard_data)
self.dashboard_helper = DashBoardHelper(self.context)
self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/")
self.dashboard_be.json.sort_keys=False # 不按照字典排序
logging.getLogger('werkzeug').setLevel(logging.ERROR)
self.dashboard_be.logger.setLevel(logging.ERROR)
@@ -68,8 +70,8 @@ class AstrBotDashBoard():
@self.dashboard_be.post("/api/authenticate")
def authenticate():
username = self.context.base_config.get("dashboard_username", "")
password = self.context.base_config.get("dashboard_password", "")
username = self.context.config_helper.dashboard.username
password = self.context.config_helper.dashboard.password
# 获得请求体
post_data = request.json
if post_data["username"] == username and post_data["password"] == password:
@@ -90,7 +92,7 @@ class AstrBotDashBoard():
@self.dashboard_be.post("/api/change_password")
def change_password():
password = self.context.base_config.get("dashboard_password", "")
password = self.context.config_helper.dashboard.password
# 获得请求体
post_data = request.json
if post_data["password"] == password:
@@ -130,40 +132,53 @@ class AstrBotDashBoard():
@self.dashboard_be.get("/api/configs")
def get_configs():
# 如果params中有namespace,则返回该namespace下的配置
# 否则返回所有配置
# namespace 为空时返回 AstrBot 配置
# 否则返回指定 namespace 的插件配置
namespace = "" if "namespace" not in request.args else request.args["namespace"]
conf = self._get_configs(namespace)
return Response(
status="success",
message="",
data=conf
).__dict__
@self.dashboard_be.get("/api/config_outline")
def get_config_outline():
outline = self._generate_outline()
return Response(
status="success",
message="",
data=outline
).__dict__
@self.dashboard_be.post("/api/configs")
def post_configs():
post_configs = request.json
try:
self.on_post_configs(post_configs)
if not namespace:
return Response(
status="success",
message="保存成功~ 机器人将在 2 秒内重启以应用新的配置。",
message="",
data=self._get_astrbot_config()
).__dict__
return Response(
status="success",
message="",
data=self._get_extension_config(namespace)
).__dict__
@self.dashboard_be.post("/api/astrbot-configs")
def post_astrbot_configs():
post_configs = request.json
try:
self.save_astrbot_configs(post_configs)
return Response(
status="success",
message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。",
data=None
).__dict__
except Exception as e:
return Response(
status="error",
message=e.__str__(),
data=self.dashboard_data.configs
data=None
).__dict__
@self.dashboard_be.post("/api/extension-configs")
def post_extension_configs():
post_configs = request.json
try:
self.save_extension_configs(post_configs)
return Response(
status="success",
message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。",
data=None
).__dict__
except Exception as e:
return Response(
status="error",
message=e.__str__(),
data=None
).__dict__
@self.dashboard_be.get("/api/extensions")
@@ -365,107 +380,41 @@ class AstrBotDashBoard():
data=None
).__dict__
def on_post_configs(self, post_configs: dict):
def save_astrbot_configs(self, post_configs: dict):
try:
if 'base_config' in post_configs:
self.dashboard_helper.save_config(
post_configs['base_config'], namespace='') # 基础配置
self.dashboard_helper.save_config(
post_configs['config'], namespace=post_configs['namespace']) # 选定配置
self.dashboard_helper.parse_default_config(
self.dashboard_data, self.context.config_helper.get_all())
# 重启
threading.Thread(target=self.astrbot_updator._reboot,
args=(2, self.context), daemon=True).start()
self.dashboard_helper.save_astrbot_config(post_configs)
threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start()
except Exception as e:
raise e
def save_extension_configs(self, post_configs: dict):
try:
self.dashboard_helper.save_extension_config(post_configs)
threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start()
except Exception as e:
raise e
def _get_astrbot_config(self):
config = self.context.config_helper.to_dict()
for key in self.dashboard_helper.config_key_dont_show:
if key in config:
del config[key]
return {
"metadata": CONFIG_METADATA_2,
"config": config,
}
def _get_configs(self, namespace: str):
if namespace == "":
ret = [self.dashboard_data.configs['data'][4],
self.dashboard_data.configs['data'][5],]
elif namespace == "internal_platform_qq_official":
ret = [self.dashboard_data.configs['data'][0],]
elif namespace == "internal_platform_qq_gocq":
ret = [self.dashboard_data.configs['data'][1],]
elif namespace == "internal_platform_general": # 全局平台配置
ret = [self.dashboard_data.configs['data'][2],]
elif namespace == "internal_llm_openai_official":
ret = [self.dashboard_data.configs['data'][3],]
elif namespace == "internal_platform_qq_aiocqhttp":
ret = [self.dashboard_data.configs['data'][6],]
else:
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
ret = [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
return ret
def _generate_outline(self):
'''
生成配置大纲。目前分为 platform(消息平台配置) 和 llm(语言模型配置) 两大类。
插件的info函数中如果带了plugin_type字段,则会被归类到对应的大纲中。目前仅支持 platform 和 llm 两种类型。
'''
outline = [
{
"type": "platform",
"name": "配置通用消息平台",
"body": [
{
"title": "通用",
"desc": "通用平台配置",
"namespace": "internal_platform_general",
"tag": ""
},
{
"title": "QQ(官方)",
"desc": "QQ官方API。支持频道、群、私聊(需获得群权限)",
"namespace": "internal_platform_qq_official",
"tag": ""
},
{
"title": "QQ(nakuru)",
"desc": "适用于 go-cqhttp",
"namespace": "internal_platform_qq_gocq",
"tag": ""
},
{
"title": "QQ(aiocqhttp)",
"desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。",
"namespace": "internal_platform_qq_aiocqhttp",
"tag": ""
}
]
},
{
"type": "llm",
"name": "配置 LLM",
"body": [
{
"title": "OpenAI Official",
"desc": "也支持使用官方接口的中转服务",
"namespace": "internal_llm_openai_official",
"tag": ""
}
]
}
]
for plugin in self.context.cached_plugins:
for item in outline:
if item['type'] == plugin.metadata.plugin_type:
item['body'].append({
"title": plugin.metadata.plugin_name,
"desc": plugin.metadata.desc,
"namespace": plugin.metadata.plugin_name,
"tag": plugin.metadata.plugin_name
})
return outline
def _get_extension_config(self, namespace: str):
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
return [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
async def get_log_history(self):
try:
+15 -18
View File
@@ -8,7 +8,6 @@ from type.types import Context
from type.config import VERSION
from SparkleLogging.utils.core import LogManager
from logging import Logger
from nakuru.entities.components import Image
from util.agent.web_searcher import search_from_bing, fetch_website_content
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -63,12 +62,12 @@ class InternalCommandHandler:
return CommandResult().message("你没有权限使用该指令。")
l = message_str.split(" ")
if len(l) == 1:
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词{context.nick}")
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词{context.config_helper.wake_prefix[0]}")
nick = l[1].strip()
if not nick:
return CommandResult().message("wake: 请指定唤醒词。")
context.config_helper.put("nick_qq", nick)
context.nick = tuple(nick)
context.config_helper.wake_prefix = [nick]
context.config_helper.save_config()
return CommandResult(
hit=True,
success=True,
@@ -90,11 +89,7 @@ class InternalCommandHandler:
ret = f"当前已经是最新版本 v{VERSION}"
else:
ret = f"发现新版本 {update_info.version},更新内容如下:\n---\n{update_info.body}\n---\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。"
return CommandResult(
hit=True,
success=False,
message_chain=ret,
)
return CommandResult().message(ret)
else:
if tokens.get(1) == "latest":
try:
@@ -184,7 +179,7 @@ class InternalCommandHandler:
async with session.get("https://soulter.top/channelbot/notice.json") as resp:
notice = (await resp.json())["notice"]
except BaseException as e:
logger.warn("An error occurred while fetching astrbot notice. Never mind, it's not important.")
logger.warning("An error occurred while fetching astrbot notice. Never mind, it's not important.")
msg = "# Help Center\n## 指令列表\n"
for key, value in self.manager.commands_handler.items():
@@ -209,10 +204,11 @@ class InternalCommandHandler:
return CommandResult(
hit=True,
success=True,
message_chain=f"网页搜索功能当前状态: {context.web_search}",
message_chain=f"网页搜索功能当前状态: {context.config_helper.llm_settings.web_search}",
)
elif l[1] == 'on':
context.web_search = True
context.config_helper.llm_settings.web_search = True
context.config_helper.save_config()
context.register_llm_tool("web_search", [{
"type": "string",
"name": "keyword",
@@ -236,7 +232,8 @@ class InternalCommandHandler:
message_chain="已开启网页搜索",
)
elif l[1] == 'off':
context.web_search = False
context.config_helper.llm_settings.web_search = False
context.config_helper.save_config()
context.unregister_llm_tool("web_search")
context.unregister_llm_tool("fetch_website_content")
@@ -253,17 +250,17 @@ class InternalCommandHandler:
)
def t2i_toggle(self, message: AstrMessageEvent, context: Context):
p = context.t2i_mode
p = context.config_helper.t2i
if p:
context.config_helper.put("qq_pic_mode", False)
context.t2i_mode = False
context.config_helper.t2i = False
context.config_helper.save_config()
return CommandResult(
hit=True,
success=True,
message_chain="已关闭文本转图片模式。",
)
context.config_helper.put("qq_pic_mode", True)
context.t2i_mode = True
context.config_helper.t2i = True
context.config_helper.save_config()
return CommandResult(
hit=True,
+5 -4
View File
@@ -61,10 +61,11 @@ class Platform():
return ret[:100] if len(ret) > 100 else ret
def check_nick(self, message_str: str) -> bool:
if self.context.nick:
for nick in self.context.nick:
if nick and message_str.strip().startswith(nick):
return True
w = self.context.config_helper.wake_prefix
if not w: return False
for nick in w:
if nick and message_str.strip().startswith(nick):
return True
return False
async def convert_to_t2i_chain(self, message_result: list) -> list:
+33 -23
View File
@@ -6,6 +6,12 @@ from type.types import Context
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import (
PlatformConfig,
AiocqhttpPlatformConfig,
NakuruPlatformConfig,
QQOfficialPlatformConfig
)
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -13,36 +19,40 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
class PlatformManager():
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
self.context = context
self.config = context.base_config
self.msg_handler = message_handler
def load_platforms(self):
tasks = []
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
logger.info("启用 QQ(nakuru 适配器)")
tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter"))
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
logger.info("启用 QQ(aiocqhttp 适配器)")
tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter"))
platforms = self.context.config_helper.platform
logger.info(f"加载 {len(platforms)} 个机器人消息平台...")
for platform in platforms:
if not platform.enable:
continue
if platform.name == "qq_official":
assert isinstance(platform, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
logger.info(f"加载 QQ官方 机器人消息平台 (appid: {platform.appid})")
tasks.append(asyncio.create_task(self.qqofficial_bot(platform), name="qqofficial-adapter"))
elif platform.name == "nakuru":
assert isinstance(platform, NakuruPlatformConfig), "nakuru: 无法识别的配置类型。"
logger.info(f"加载 QQ(nakuru) 机器人消息平台 ({platform.host}, {platform.websocket_port}, {platform.port})")
tasks.append(asyncio.create_task(self.nakuru_bot(platform), name="nakuru-adapter"))
elif platform.name == "aiocqhttp":
assert isinstance(platform, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
logger.info("加载 QQ(aiocqhttp) 机器人消息平台")
tasks.append(asyncio.create_task(self.aiocq_bot(platform), name="aiocqhttp-adapter"))
# QQ频道
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
logger.info("启用 QQ(官方 API) 机器人消息平台")
tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter"))
return tasks
async def gocq_bot(self):
async def nakuru_bot(self, config: NakuruPlatformConfig):
'''
运行 QQ(nakuru 适配器)
'''
from model.platform.qq_nakuru import QQGOCQ
from model.platform.qq_nakuru import QQNakuru
noticed = False
host = self.config.get("gocq_host", "127.0.0.1")
port = self.config.get("gocq_websocket_port", 6700)
http_port = self.config.get("gocq_http_port", 5700)
host = config.host
port = config.websocket_port
http_port = config.port
logger.info(
f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}")
while True:
@@ -56,30 +66,30 @@ class PlatformManager():
logger.info("nakuru 适配器已连接。")
break
try:
qq_gocq = QQGOCQ(self.context, self.msg_handler)
qq_gocq = QQNakuru(self.context, self.msg_handler, config)
self.context.platforms.append(RegisteredPlatform(
platform_name="nakuru", platform_instance=qq_gocq, origin="internal"))
await qq_gocq.run()
except BaseException as e:
logger.error("启动 nakuru 适配器时出现错误: " + str(e))
def aiocq_bot(self):
def aiocq_bot(self, config):
'''
运行 QQ(aiocqhttp 适配器)
'''
from model.platform.qq_aiocqhttp import AIOCQHTTP
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler)
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler, config)
self.context.platforms.append(RegisteredPlatform(
platform_name="aiocqhttp", platform_instance=qq_aiocqhttp, origin="internal"))
return qq_aiocqhttp.run_aiocqhttp()
def qqchan_bot(self):
def qqofficial_bot(self, config):
'''
运行 QQ 官方机器人适配器
'''
try:
from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(self.context, self.msg_handler)
qqchannel_bot = QQOfficial(self.context, self.msg_handler, config)
self.context.platforms.append(RegisteredPlatform(
platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal"))
return qqchannel_bot.run()
+21 -12
View File
@@ -13,19 +13,26 @@ from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig
logger: Logger = LogManager.GetLogger(log_name='astrbot')
class AIOCQHTTP(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
def __init__(self, context: Context,
message_handler: MessageHandler,
platform_config: PlatformConfig) -> None:
super().__init__("aiocqhttp", context)
assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
self.message_handler = message_handler
self.waiting = {}
self.context = context
self.unique_session = self.context.unique_session
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
self.host = self.context.base_config['aiocqhttp']['ws_reverse_host']
self.port = self.context.base_config['aiocqhttp']['ws_reverse_port']
self.config = platform_config
self.unique_session = context.config_helper.platform_settings.unique_session
self.announcement = context.config_helper.platform_settings.welcome_message_when_join
self.host = platform_config.ws_reverse_host
self.port = platform_config.ws_reverse_port
self.admins = context.config_helper.admins_id
def convert_message(self, event: Event) -> AstrBotMessage:
@@ -105,7 +112,7 @@ class AIOCQHTTP(Platform):
while self.context.running:
await asyncio.sleep(1)
def pre_check(self, message: AstrBotMessage) -> bool:
async def pre_check(self, message: AstrBotMessage) -> bool:
# if message chain contains Plain components or
# At components which points to self_id, return True
if message.type == MessageType.FRIEND_MESSAGE:
@@ -114,7 +121,7 @@ class AIOCQHTTP(Platform):
if isinstance(comp, At) and str(comp.qq) == message.self_id:
return True, "at"
# check commands which ignore prefix
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
if await self.context.command_manager.check_command_ignore_prefix(message.message_str):
return True, "command"
# check nicks
if self.check_nick(message.message_str):
@@ -125,14 +132,13 @@ class AIOCQHTTP(Platform):
logger.info(
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
ok, reason = self.pre_check(message)
ok, reason = await self.pre_check(message)
if not ok:
return
# 解析 role
sender_id = str(message.sender.user_id)
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
@@ -167,6 +173,8 @@ class AIOCQHTTP(Platform):
# 如果是等待回复的消息
if message.session_id in self.waiting and self.waiting[message.session_id] == '':
self.waiting[message.session_id] = message
return message_result
async def reply_msg(self,
@@ -182,17 +190,18 @@ class AIOCQHTTP(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
await self._reply(message, rendered_images)
return
return rendered_images
except BaseException as e:
logger.warn(traceback.format_exc())
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
await self._reply(message, res)
return res
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
await self.record_metrics()
+22 -16
View File
@@ -18,6 +18,7 @@ from type.command import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, NakuruPlatformConfig
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -28,47 +29,53 @@ class FakeSource:
self.group_id = group_id
class QQGOCQ(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
class QQNakuru(Platform):
def __init__(self, context: Context,
message_handler: MessageHandler,
platform_config: PlatformConfig) -> None:
super().__init__("nakuru", context)
assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。"
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.message_handler = message_handler
self.waiting = {}
self.context = context
self.unique_session = self.context.unique_session
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
self.unique_session = context.config_helper.platform_settings.unique_session
self.announcement = context.config_helper.platform_settings.welcome_message_when_join
self.config = platform_config
self.admins = context.config_helper.admins_id
self.client = CQHTTP(
host=self.context.base_config.get("gocq_host", "127.0.0.1"),
port=self.context.base_config.get("gocq_websocket_port", 6700),
http_port=self.context.base_config.get("gocq_http_port", 5700),
host=self.config.host,
port=self.config.websocket_port,
http_port=self.config.port
)
gocq_app = self.client
@gocq_app.receiver("GroupMessage")
async def _(app: CQHTTP, source: GroupMessage):
if self.context.base_config.get("gocq_react_group", True):
if self.config.enable_group:
abm = self.convert_message(source)
await self.handle_msg(abm)
@gocq_app.receiver("FriendMessage")
async def _(app: CQHTTP, source: FriendMessage):
if self.context.base_config.get("gocq_react_friend", True):
if self.config.enable_direct_message:
abm = self.convert_message(source)
await self.handle_msg(abm)
@gocq_app.receiver("GroupMemberIncrease")
async def _(app: CQHTTP, source: GroupMemberIncrease):
if self.context.base_config.get("gocq_react_group_increase", True):
if self.config.enable_group_increase:
await app.sendGroupMessage(source.group_id, [
Plain(text=self.announcement)
])
@gocq_app.receiver("GuildMessage")
async def _(app: CQHTTP, source: GuildMessage):
if self.cc.get("gocq_react_guild", True):
if self.config.enable_guild:
abm = self.convert_message(source)
await self.handle_msg(abm)
@@ -117,8 +124,7 @@ class QQGOCQ(Platform):
# 解析 role
sender_id = str(message.raw_message.user_id)
if sender_id == self.context.base_config.get('admin_qq', '') or \
sender_id in self.context.base_config.get('other_admins', []):
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
@@ -179,7 +185,7 @@ class QQGOCQ(Platform):
res = [Plain(text=res), ]
# if image mode, put all Plain texts into a new picture.
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(res)
if rendered_images:
try:
@@ -226,7 +232,7 @@ class QQGOCQ(Platform):
plain_text_len += len(i.text)
elif isinstance(i, Image):
image_num += 1
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
if plain_text_len > self.context.config_helper.platform_settings.forward_threshold or image_num > 1:
# 删除At
for i in message_chain:
if isinstance(i, At):
+21 -30
View File
@@ -19,6 +19,7 @@ from nakuru.entities.components import *
from SparkleLogging.utils.core import LogManager
from logging import Logger
from astrbot.message.handler import MessageHandler
from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -52,32 +53,37 @@ class botClient(Client):
class QQOfficial(Platform):
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
def __init__(self, context: Context,
message_handler: MessageHandler,
platform_config: PlatformConfig,
test_mode = False) -> None:
super().__init__("qqofficial", context)
assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.message_handler = message_handler
self.waiting: dict = {}
self.context = context
self.config = platform_config
self.admins = context.config_helper.admins_id
self.appid = context.base_config['qqbot']['appid']
self.token = context.base_config['qqbot']['token']
self.secret = context.base_config['qqbot_secret']
self.unique_session = context.unique_session
qq_group = context.base_config['qqofficial_enable_group_message']
self.appid = platform_config.appid
self.secret = platform_config.secret
self.unique_session = context.config_helper.platform_settings.unique_session
qq_group = platform_config.enable_group_c2c
guild_dm = platform_config.enable_guild_direct_message
if qq_group:
self.intents = botpy.Intents(
public_messages=True,
public_guild_messages=True,
direct_message=context.base_config['direct_message_mode']
direct_message=guild_dm
)
else:
self.intents = botpy.Intents(
public_guild_messages=True,
direct_message=context.base_config['direct_message_mode']
direct_message=guild_dm
)
self.client = botClient(
intents=self.intents,
@@ -169,24 +175,10 @@ class QQOfficial(Platform):
return abm
def run(self):
try:
return self.client.start(
appid=self.appid,
secret=self.secret
)
except BaseException as e:
# 早期的 qq-botpy 版本使用 token 登录。
logger.error(traceback.format_exc())
self.client = botClient(
intents=self.intents,
bot_log=False,
timeout=20,
)
self.client.set_platform(self)
return self.client.start(
appid=self.appid,
token=self.token
)
return self.client.start(
appid=self.appid,
secret=self.secret
)
async def handle_msg(self, message: AstrBotMessage):
assert isinstance(message.raw_message, (botpy.message.Message,
@@ -211,8 +203,7 @@ class QQOfficial(Platform):
# 解析出 role
sender_id = message.sender.user_id
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
sender_id in self.context.base_config.get('other_admins', None):
if sender_id in self.admins:
role = 'admin'
else:
role = 'member'
@@ -252,7 +243,7 @@ class QQOfficial(Platform):
msg_ref = None
rendered_images = []
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list):
rendered_images = await self.convert_to_t2i_chain(result_message)
if isinstance(result_message, list):
+25 -30
View File
@@ -8,18 +8,18 @@ import traceback
import base64
from openai import AsyncOpenAI
from openai.types.images_response import ImagesResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai._exceptions import *
from util.io import download_image_by_url
from astrbot.persist.helper import dbConn
from model.provider.provider import Provider
from util.cmd_config import LLMConfig
from SparkleLogging.utils.core import LogManager
from logging import Logger
from typing import List, Dict
from type.types import Context
from dataclasses import asdict
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@@ -47,22 +47,16 @@ MODELS = {
}
class ProviderOpenAIOfficial(Provider):
def __init__(self, context: Context) -> None:
def __init__(self, llm_config: LLMConfig) -> None:
super().__init__()
os.makedirs("data/openai", exist_ok=True)
self.context = context
self.key_data_path = "data/openai/keys.json"
self.api_keys = []
self.chosen_api_key = None
self.base_url = None
self.llm_config = llm_config
self.keys_data = {} # 记录超额
cfg = context.base_config['openai']
if cfg['key']: self.api_keys = cfg['key']
if cfg['api_base']: self.base_url = cfg['api_base']
if llm_config.key: self.api_keys = llm_config.key
if llm_config.api_base: self.base_url = llm_config.api_base
if not self.api_keys:
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
else:
@@ -75,18 +69,21 @@ class ProviderOpenAIOfficial(Provider):
api_key=self.chosen_api_key,
base_url=self.base_url
)
self.model_configs: Dict = cfg['chatGPTConfigs']
super().set_curr_model(self.model_configs['model'])
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
super().set_curr_model(llm_config.model_config.model)
if llm_config.image_generation_model_config:
self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config)
self.session_memory: Dict[str, List] = {} # 会话记忆
self.session_memory_lock = threading.Lock()
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小
logger.info("正在载入分词器 cl100k_base...")
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
logger.info("分词器载入完成。")
self.DEFAULT_PERSONALITY = context.default_personality
self.DEFAULT_PERSONALITY = {
"prompt": self.llm_config.default_personality,
"name": "default"
}
self.curr_personality = self.DEFAULT_PERSONALITY
self.session_personality = {} # 记录了某个session是否已设置人格。
# 从 SQLite DB 读取历史记录
@@ -133,7 +130,7 @@ class ProviderOpenAIOfficial(Provider):
self.session_personality = {} # 重置
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
tokens_num = len(encoded_prompt)
model = self.model_configs['model']
model = self.get_curr_model()
if model in MODELS and tokens_num > MODELS[model] - 500:
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
@@ -172,7 +169,7 @@ class ProviderOpenAIOfficial(Provider):
for record in self.session_memory[session_id]:
if "user" in record and record['user']:
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
logger.warn(f"由于当前模型 {self.get_curr_model()} 不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
continue
context.append(record['user'])
if "AI" in record and record['AI']:
@@ -184,7 +181,7 @@ class ProviderOpenAIOfficial(Provider):
'''
是否是 LVM
'''
return self.model_configs['model'].startswith("gpt-4")
return self.get_curr_model().startswith("gpt-4")
async def get_models(self):
try:
@@ -237,7 +234,7 @@ class ProviderOpenAIOfficial(Provider):
self.session_memory[session_id].append(message)
# 根据 模型的上下文窗口 淘汰掉多余的记录
curr_model = self.model_configs['model']
curr_model = self.get_curr_model()
if curr_model in MODELS:
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
# if message['usage_tokens'] > maxium_tokens_num:
@@ -316,7 +313,7 @@ class ProviderOpenAIOfficial(Provider):
# 1. 可以保证之后 pop 的时候不会出现问题
# 2. 可以保证不会超过最大 token 数
_encoded_prompt = self.tokenizer.encode(prompt)
curr_model = self.model_configs['model']
curr_model = self.get_curr_model()
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
prompt = self.tokenizer.decode(_encoded_prompt)
@@ -327,7 +324,7 @@ class ProviderOpenAIOfficial(Provider):
# 获取上下文,openai 格式
contexts = await self.retrieve_context(session_id)
conf = self.model_configs
conf = asdict(self.llm_config.model_config)
if extra_conf: conf.update(extra_conf)
# start request
@@ -358,10 +355,10 @@ class ProviderOpenAIOfficial(Provider):
if ok: continue
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
except BadRequestError as e:
retry += 1
logger.warn(f"OpenAI 请求异常:{e}")
if "image_url is only supported by certain models." in str(e):
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
raise e
raise Exception(f"当前模型 { self.get_curr_model() } 不支持图片输入,请更换模型。")
except RateLimitError as e:
if "You exceeded your current quota" in str(e):
self.keys_data[self.chosen_api_key] = False
@@ -437,11 +434,10 @@ class ProviderOpenAIOfficial(Provider):
'''
retry = 0
conf = self.image_generator_model_configs
super().accu_model_stat(model=conf['model'])
if not conf:
logger.error("OpenAI 图片生成模型配置不存在。")
raise Exception("OpenAI 图片生成模型配置不存在。")
super().accu_model_stat(model=conf['model'])
while retry < 3:
try:
images_response = await self.client.images.generate(
@@ -497,12 +493,11 @@ class ProviderOpenAIOfficial(Provider):
return contexts_str, len(self.session_memory[session_id])
def set_model(self, model: str):
self.model_configs['model'] = model
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
# TODO: 更新配置文件
super().set_curr_model(model)
def get_configs(self):
return self.model_configs
return asdict(self.llm_config)
def get_keys_data(self):
return self.keys_data
+6 -1
View File
@@ -1,3 +1,4 @@
import copy
from aiocqhttp import Event
class MockOneBotMessage():
@@ -10,4 +11,8 @@ class MockOneBotMessage():
return self.group_event_sample
def create_random_direct_message(self):
return self.friend_event_sample
return self.friend_event_sample
def create_msg(self, text: str):
self.group_event_sample.message = [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': text}, 'type': 'text'}]
return self.group_event_sample
+16 -7
View File
@@ -3,19 +3,19 @@ import botpy.message
class MockQQOfficialMessage():
def __init__(self):
# 这些数据已经经过去敏处理
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'}
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'}
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'}
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T19:58:52+08:00'}
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:06:32+08:00'}
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:15:24+08:00'}
self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849"
self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'}
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'}
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
def create_random_group_message(self):
@@ -42,4 +42,13 @@ class MockQQOfficialMessage():
)
return mocked
def create_msg(self, text: str):
sample = self.group_plain_text_sample.copy()
sample['content'] = text
mocked = botpy.message.Message(
api=None,
event_id=self.group_event_id_sample,
data=sample
)
return mocked
+89 -4
View File
@@ -8,11 +8,14 @@ from tests.mocks.onebot import MockOneBotMessage
from astrbot.bootstrap import AstrBotBootstrap
from model.platform.qq_official import QQOfficial
from model.platform.qq_aiocqhttp import AIOCQHTTP
from model.provider.openai_official import ProviderOpenAIOfficial
from type.astrbot_message import *
from type.message_event import *
from SparkleLogging.utils.core import LogManager
from logging import Formatter
from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig
logger = LogManager.GetLogger(
log_name='astrbot',
out_to_console=True,
@@ -22,10 +25,23 @@ pytest_plugins = ('pytest_asyncio',)
os.environ['TEST_MODE'] = 'on'
bootstrap = AstrBotBootstrap()
asyncio.run(bootstrap.run())
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler)
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler)
llm_config = bootstrap.context.config_helper.llm[0]
llm_config.api_base = os.environ['OPENAI_API_BASE']
llm_config.key = [os.environ['OPENAI_API_KEY']]
llm_config.model_config.model = os.environ['LLM_MODEL']
llm_config.model_config.max_tokens = 1000
llm_provider = ProviderOpenAIOfficial(llm_config)
asyncio.run(bootstrap.run())
bootstrap.message_handler.provider = llm_provider
bootstrap.config_helper.wake_prefix = ["/"]
bootstrap.config_helper.admins_id = ["905617992"]
for p_config in bootstrap.context.config_helper.platform:
if isinstance(p_config, QQOfficialPlatformConfig):
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler, p_config)
elif isinstance(p_config, AiocqhttpPlatformConfig):
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler, p_config)
class TestBasicMessageHandle():
@pytest.mark.asyncio
@@ -62,4 +78,73 @@ class TestBasicMessageHandle():
event = MockOneBotMessage().create_random_direct_message()
abm = aiocqhttp.convert_message(event)
ret = await aiocqhttp.handle_msg(abm)
print(ret)
print(ret)
class TestInteralCommandHsandle():
def create(self, text: str):
event = MockOneBotMessage().create_msg(text)
abm = aiocqhttp.convert_message(event)
return abm
async def fast_test(self, text: str):
abm = self.create(text)
ret = await aiocqhttp.handle_msg(abm)
print(f"Command: {text}, Result: {ret.result_message}")
return ret
@pytest.mark.asyncio
async def test_config_save(self):
abm = self.create("/websearch on")
ret = await aiocqhttp.handle_msg(abm)
assert bootstrap.context.config_helper.llm_settings.web_search \
== bootstrap.config_helper.get("llm_settings")['web_search']
@pytest.mark.asyncio
async def test_websearch(self):
await self.fast_test("/websearch")
await self.fast_test("/websearch on")
await self.fast_test("/websearch off")
@pytest.mark.asyncio
async def test_help(self):
await self.fast_test("/help")
@pytest.mark.asyncio
async def test_myid(self):
await self.fast_test("/myid")
@pytest.mark.asyncio
async def test_wake(self):
await self.fast_test("/wake")
await self.fast_test("/wake #")
assert "#" in bootstrap.context.config_helper.wake_prefix
assert "#" in bootstrap.context.config_helper.get("wake_prefix")
await self.fast_test("#wake /")
@pytest.mark.asyncio
async def test_sleep(self):
await self.fast_test("/provider")
@pytest.mark.asyncio
async def test_update(self):
await self.fast_test("/update")
@pytest.mark.asyncio
async def test_t2i(self):
if not bootstrap.context.config_helper.t2i:
abm = self.create("/t2i")
await aiocqhttp.handle_msg(abm)
await self.fast_test("/help")
class TestLLMChat():
@pytest.mark.asyncio
async def test_llm_chat(self):
os.environ["TEST_LLM"] = "on"
ret = await llm_provider.text_chat("Just reply `ok`", "test")
print(ret)
event = MockOneBotMessage().create_msg("Just reply `ok`")
abm = aiocqhttp.convert_message(event)
ret = await aiocqhttp.handle_msg(abm)
print(ret)
os.environ["TEST_LLM"] = "off"
+281
View File
@@ -73,3 +73,284 @@ DEFAULT_CONFIG = {
"ws_reverse_port": 0,
}
}
# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D
DEFAULT_CONFIG_VERSION_2 = {
"config_version": 2,
"platform": [
{
"id": "default",
"name": "qq_official",
"enable": False,
"appid": "",
"secret": "",
"enable_group_c2c": True,
"enable_guild_direct_message": True,
},
{
"id": "default",
"name": "nakuru",
"enable": False,
"host": "172.0.0.1",
"port": 5700,
"websocket_port": 6700,
"enable_group": True,
"enable_guild": True,
"enable_direct_message": True,
"enable_group_increase": True,
},
{
"id": "default",
"name": "aiocqhttp",
"enable": False,
"ws_reverse_host": "",
"ws_reverse_port": 6199,
}
],
"platform_settings": {
"unique_session": False,
"welcome_message_when_join": "",
"rate_limit": {
"time": 60,
"count": 30,
},
"reply_prefix": "",
"forward_threshold": 200, # 转发消息的阈值
},
"llm": [
{
"id": "default",
"name": "openai",
"enable": True,
"key": [],
"api_base": "",
"prompt_prefix": "",
"default_personality": "",
"model_config": {
"model": "gpt-4o",
"max_tokens": 6000,
"temperature": 0.9,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
},
"image_generation_model_config": {
"enable": True,
"model": "dall-e-3",
"size": "1024x1024",
"style": "vivid",
"quality": "standard",
}
},
],
"llm_settings": {
"wake_prefix": "",
"web_search": False,
},
"content_safety": {
"baidu_aip": {
"enable": False,
"app_id": "",
"api_key": "",
"secret_key": "",
},
"internal_keywords": {
"enable": True,
"extra_keywords": [],
}
},
"wake_prefix": [],
"t2i": True,
"dump_history_interval": 10,
"admins_id": [],
"https_proxy": "",
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "",
"password": "",
},
}
# 这个是用于迁移旧版本配置文件的映射表
MAPPINGS_1_2 = [
[["qqbot", "enable"], ["platform", 0, "enable"]],
[["qqbot", "appid"], ["platform", 0, "appid"]],
[["qqbot", "token"], ["platform", 0, "secret"]],
[["qqofficial_enable_group_message"], ["platform", 0, "enable_group_c2c"]],
[["direct_message_mode"], ["platform", 0, "enable_guild_direct_message"]],
[["gocqbot", "enable"], ["platform", 1, "enable"]],
[["gocq_host"], ["platform", 1, "host"]],
[["gocq_http_port"], ["platform", 1, "port"]],
[["gocq_websocket_port"], ["platform", 1, "websocket_port"]],
[["gocq_react_group"], ["platform", 1, "enable_group"]],
[["gocq_react_guild"], ["platform", 1, "enable_guild"]],
[["gocq_react_friend"], ["platform", 1, "enable_direct_message"]],
[["gocq_react_group_increase"], ["platform", 1, "enable_group_increase"]],
[["aiocqhttp", "enable"], ["platform", 2, "enable"]],
[["aiocqhttp", "ws_reverse_host"], ["platform", 2, "ws_reverse_host"]],
[["aiocqhttp", "ws_reverse_port"], ["platform", 2, "ws_reverse_port"]],
[["uniqueSessionMode"], ["platform_settings", "unique_session"]],
[["qq_welcome"], ["platform_settings", "welcome_message_when_join"]],
[["limit", "time"], ["platform_settings", "rate_limit", "time"]],
[["limit", "count"], ["platform_settings", "rate_limit", "count"]],
[["reply_prefix"], ["platform_settings", "reply_prefix"]],
[["qq_forward_threshold"], ["platform_settings", "forward_threshold"]],
[["openai", "key"], ["llm", 0, "key"]],
[["openai", "api_base"], ["llm", 0, "api_base"]],
[["openai", "chatGPTConfigs", "model"], ["llm", 0, "model_config", "model"]],
[["openai", "chatGPTConfigs", "max_tokens"], ["llm", 0, "model_config", "max_tokens"]],
[["openai", "chatGPTConfigs", "temperature"], ["llm", 0, "model_config", "temperature"]],
[["openai", "chatGPTConfigs", "top_p"], ["llm", 0, "model_config", "top_p"]],
[["openai", "chatGPTConfigs", "frequency_penalty"], ["llm", 0, "model_config", "frequency_penalty"]],
[["openai", "chatGPTConfigs", "presence_penalty"], ["llm", 0, "model_config", "presence_penalty"]],
[["default_personality_str"], ["llm", 0, "default_personality"]],
[["llm_env_prompt"], ["llm", 0, "prompt_prefix"]],
[["openai_image_generate", "model"], ["llm", 0, "image_generation_model_config", "model"]],
[["openai_image_generate", "size"], ["llm", 0, "image_generation_model_config", "size"]],
[["openai_image_generate", "style"], ["llm", 0, "image_generation_model_config", "style"]],
[["openai_image_generate", "quality"], ["llm", 0, "image_generation_model_config", "quality"]],
[["llm_wake_prefix"], ["llm_settings", "wake_prefix"]],
[["baidu_aip", "enable"], ["content_safety", "baidu_aip", "enable"]],
[["baidu_aip", "app_id"], ["content_safety", "baidu_aip", "app_id"]],
[["baidu_aip", "api_key"], ["content_safety", "baidu_aip", "api_key"]],
[["baidu_aip", "secret_key"], ["content_safety", "baidu_aip", "secret_key"]],
[["qq_pic_mode"], ["t2i"]],
[["dump_history_interval"], ["dump_history_interval"]],
[["other_admins"], ["admins_id"]],
[["http_proxy"], ["http_proxy"]],
[["https_proxy"], ["https_proxy"]],
[["dashboard_username"], ["dashboard", "username"]],
[["dashboard_password"], ["dashboard", "password"]],
[["nick_qq"], ["wake_prefix"]],
]
CONFIG_METADATA_2 = {
"config_version": {"description": "配置版本", "type": "int"},
"platform": {
"description": "平台配置",
"type": "list",
"items": {
"name": {"description": "平台名称", "type": "string"},
"enable": {"description": "是否启用", "type": "bool"},
"appid": {"description": "应用ID", "type": "string"},
"secret": {"description": "应用密钥", "type": "string"},
"enable_group_c2c": {"description": "启用群C2C", "type": "bool"},
"enable_guild_direct_message": {"description": "启用公会直接消息", "type": "bool"},
"host": {"description": "主机地址", "type": "string"},
"port": {"description": "端口", "type": "int"},
"websocket_port": {"description": "WebSocket端口", "type": "int"},
"ws_reverse_host": {"description": "WebSocket反向主机", "type": "string"},
"ws_reverse_port": {"description": "WebSocket反向端口", "type": "int"},
"enable_group": {"description": "启用群组", "type": "bool"},
"enable_guild": {"description": "启用公会", "type": "bool"},
"enable_direct_message": {"description": "启用直接消息", "type": "bool"},
"enable_group_increase": {"description": "启用群组增加", "type": "bool"},
}
},
"platform_settings": {
"description": "平台设置",
"type": "object",
"items": {
"unique_session": {"description": "唯一会话", "type": "bool"},
"welcome_message_when_join": {"description": "加入时欢迎信息", "type": "string"},
"rate_limit": {
"description": "速率限制",
"type": "object",
"items": {
"time": {"description": "时间", "type": "int"},
"count": {"description": "计数", "type": "int"},
}
},
"reply_prefix": {"description": "回复前缀", "type": "string"},
"forward_threshold": {"description": "转发消息的阈值", "type": "int"},
}
},
"llm": {
"description": "大语言模型配置",
"type": "list",
"items": {
"name": {"description": "模型名称", "type": "string"},
"enable": {"description": "是否启用", "type": "bool"},
"key": {"description": "密钥", "type": "list", "items": {"type": "string"}},
"api_base": {"description": "API基础URL", "type": "string"},
"prompt_prefix": {"description": "提示前缀", "type": "string"},
"default_personality": {"description": "默认个性", "type": "string"},
"model_config": {
"description": "模型配置",
"type": "object",
"items": {
"model": {"description": "模型名称", "type": "string"},
"max_tokens": {"description": "最大令牌数", "type": "int"},
"temperature": {"description": "温度", "type": "float"},
"top_p": {"description": "Top P值", "type": "float"},
"frequency_penalty": {"description": "频率惩罚", "type": "float"},
"presence_penalty": {"description": "存在惩罚", "type": "float"},
}
},
"image_generation_model_config": {
"description": "图像生成模型配置",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"model": {"description": "模型名称", "type": "string"},
"size": {"description": "图像尺寸", "type": "string"},
"style": {"description": "图像风格", "type": "string"},
"quality": {"description": "图像质量", "type": "string"},
}
},
}
},
"llm_settings": {
"description": "大语言模型设置",
"type": "object",
"items": {
"wake_prefix": {"description": "唤醒前缀", "type": "string"},
"web_search": {"description": "启用网络搜索", "type": "bool"},
}
},
"content_safety": {
"description": "内容安全配置",
"type": "object",
"items": {
"baidu_aip": {
"description": "百度AI平台配置",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"app_id": {"description": "应用ID", "type": "string"},
"api_key": {"description": "API密钥", "type": "string"},
"secret_key": {"description": "秘密密钥", "type": "string"},
}
},
"internal_keywords": {
"description": "内部关键词过滤",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"extra_keywords": {"description": "额外关键词", "type": "list", "items": {"type": "string"}},
}
}
}
},
"wake_prefix": {"description": "唤醒前缀列表", "type": "list", "items": {"type": "string"}},
"t2i": {"description": "文本转图像功能", "type": "bool"},
"dump_history_interval": {"description": "历史记录转储间隔", "type": "int"},
"admins_id": {"description": "管理员ID列表", "type": "list", "items": {"type": "int"}},
"https_proxy": {"description": "HTTPS代理", "type": "string"},
"http_proxy": {"description": "HTTP代理", "type": "string"},
"dashboard": {
"description": "仪表盘配置",
"type": "object",
"items": {
"enable": {"description": "是否启用", "type": "bool"},
"username": {"description": "用户名", "type": "string"},
"password": {"description": "密码", "type": "string"},
}
},
}
+9 -9
View File
@@ -3,7 +3,7 @@ from asyncio import Task
from type.register import *
from typing import List, Awaitable
from logging import Logger
from util.cmd_config import CmdConfig
from util.cmd_config import AstrBotConfig
from util.t2i.renderer import TextToImageRenderer
from util.updator.astrbot_updator import AstrBotUpdator
from util.image_uploader import ImageUploader
@@ -21,19 +21,19 @@ class Context:
'''
def __init__(self):
self.running = True
self.logger: Logger = None
self.base_config: dict = None # 配置(期望启动机器人后是不变的)
self.config_helper: CmdConfig = None
self.config_helper: AstrBotConfig = None
self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件
self.platforms: List[RegisteredPlatform] = []
self.llms: List[RegisteredLLM] = []
self.default_personality: dict = None
self.unique_session = False # 独立会话
self.version: str = None # 机器人版本
self.nick: tuple = None # gocq 的唤醒词
self.t2i_mode = False
self.web_search = False # 是否开启了网页搜索
# self.unique_session = False # 独立会话
# self.version: str = None # 机器人版本
# self.nick: tuple = None # gocq 的唤醒词
# self.t2i_mode = False
# self.web_search = False # 是否开启了网页搜索
self.metrics_uploader = None
self.updator: AstrBotUpdator = None
@@ -48,7 +48,7 @@ class Context:
self.running = True
# useless
self.reply_prefix = ""
# self.reply_prefix = ""
def register_commands(self,
plugin_name: str,
+250 -46
View File
@@ -1,33 +1,242 @@
import os
import json
from type.config import DEFAULT_CONFIG
import logging
from type.config import DEFAULT_CONFIG, DEFAULT_CONFIG_VERSION_2, MAPPINGS_1_2
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional
cpath = "data/cmd_config.json"
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
logger = logging.getLogger("astrbot")
def check_exist():
if not os.path.exists(cpath):
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump({}, f, ensure_ascii=False)
f.flush()
@dataclass
class RateLimit:
time: int = 60
count: int = 30
@dataclass
class PlatformSettings:
unique_session: bool = False
welcome_message_when_join: str = ""
rate_limit: RateLimit = field(default_factory=RateLimit)
reply_prefix: str = ""
forward_threshold: int = 200
def __post_init__(self):
self.rate_limit = RateLimit(**self.rate_limit)
@dataclass
class PlatformConfig():
id: str = ""
name: str = ""
enable: bool = False
@dataclass
class QQOfficialPlatformConfig(PlatformConfig):
appid: str = ""
secret: str = ""
enable_group_c2c: bool = True
enable_guild_direct_message: bool = True
@dataclass
class NakuruPlatformConfig(PlatformConfig):
host: str = "172.0.0.1",
port: int = 5700,
websocket_port: int = 6700,
enable_group: bool = True,
enable_guild: bool = True,
enable_direct_message: bool = True,
enable_group_increase: bool = True
@dataclass
class AiocqhttpPlatformConfig(PlatformConfig):
ws_reverse_host: str = ""
ws_reverse_port: int = 6199
@dataclass
class ModelConfig:
model: str = "gpt-4o"
max_tokens: int = 6000
temperature: float = 0.9
top_p: float = 1
frequency_penalty: float = 0
presence_penalty: float = 0
@dataclass
class ImageGenerationModelConfig:
enable: bool = True
model: str = "dall-e-3"
size: str = "1024x1024"
style: str = "vivid"
quality: str = "standard"
@dataclass
class LLMConfig:
id: str = ""
name: str = "openai"
enable: bool = True
key: List[str] = field(default_factory=list)
api_base: str = ""
prompt_prefix: str = ""
default_personality: str = ""
model_config: ModelConfig = field(default_factory=ModelConfig)
image_generation_model_config: Optional[ImageGenerationModelConfig] = None
def __post_init__(self):
self.model_config = ModelConfig(**self.model_config)
if self.image_generation_model_config:
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config)
@dataclass
class LLMSettings:
wake_prefix: str = ""
web_search: bool = False
@dataclass
class BaiduAIPConfig:
enable: bool = False
app_id: str = ""
api_key: str = ""
secret_key: str = ""
@dataclass
class InternalKeywordsConfig:
enable: bool = True
extra_keywords: List[str] = field(default_factory=list)
@dataclass
class ContentSafetyConfig:
baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig)
internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig)
def __post_init__(self):
self.baidu_aip = BaiduAIPConfig(**self.baidu_aip)
self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords)
@dataclass
class DashboardConfig:
enable: bool = True
username: str = ""
password: str = ""
@dataclass
class AstrBotConfig():
config_version: int = 2
platform_settings: PlatformSettings = field(default_factory=PlatformSettings)
llm: List[LLMConfig] = field(default_factory=list)
llm_settings: LLMSettings = field(default_factory=LLMSettings)
content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig)
t2i: bool = True
dump_history_interval: int = 10
admins_id: List[str] = field(default_factory=list)
https_proxy: str = ""
http_proxy: str = ""
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
platform: List[PlatformConfig] = field(default_factory=list)
wake_prefix: List[str] = field(default_factory=list)
class CmdConfig():
def __init__(self) -> None:
self.cached_config: dict = {}
self.init_configs()
# compability
if isinstance(self.wake_prefix, str):
self.wake_prefix = [self.wake_prefix]
def load_from_dict(self, data: Dict):
'''从字典中加载配置到对象。
@note: 适用于 version 2 配置文件。
'''
self.config_version=data.get("version", 2)
self.platform=[]
for p in data.get("platform", []):
if 'name' not in p:
logger.warning("A platform config missing name, skipping.")
continue
if p["name"] == "qq_official":
self.platform.append(QQOfficialPlatformConfig(**p))
elif p["name"] == "nakuru":
self.platform.append(NakuruPlatformConfig(**p))
elif p["name"] == "aiocqhttp":
self.platform.append(AiocqhttpPlatformConfig(**p))
else:
self.platform.append(PlatformConfig(**p))
self.platform_settings=PlatformSettings(**data.get("platform_settings", {}))
self.llm=[LLMConfig(**l) for l in data.get("llm", [])]
self.llm_settings=LLMSettings(**data.get("llm_settings", {}))
self.content_safety=ContentSafetyConfig(**data.get("content_safety", {}))
self.t2i=data.get("t2i", True)
self.dump_history_interval=data.get("dump_history_interval", 10)
self.admins_id=data.get("admins_id", [])
self.https_proxy=data.get("https_proxy", "")
self.http_proxy=data.get("http_proxy", "")
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
self.wake_prefix=data.get("wake_prefix", [])
def migrate_config_1_2(self, old: dict) -> dict:
'''将配置文件从版本 1 迁移至版本 2'''
logger.info("正在更新配置文件到 version 2...")
new_config = DEFAULT_CONFIG_VERSION_2
mappings = MAPPINGS_1_2
def set_nested_value(d, keys, value):
cursor = d
for key in keys[:-1]:
cursor = cursor[key]
cursor[keys[-1]] = value
for old_path, new_path in mappings:
value = old
try:
for key in old_path:
value = value[key] # soooooo convenient!!
set_nested_value(new_config, new_path, value)
except KeyError:
# 如果旧配置中没有这个键,跳过,即使用新配置的默认值
continue
logger.info("配置文件更新完成。")
return new_config
def flush_config(self, config: dict = None):
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False)
f.flush()
def save_config(self):
'''将现存配置写入文件'''
self.flush_config(self.to_dict())
def init_configs(self):
'''
初始化必需的配置项
'''
self.init_config_items(DEFAULT_CONFIG)
'''初始化必需的配置项'''
config = None
if not self.check_exist():
self.flush_config()
config = DEFAULT_CONFIG_VERSION_2
else:
config = self.get_all()
# check if the config is outdated
if 'config_version' not in config: # version 1
config = self.migrate_config_1_2(config)
self.flush_config(config)
_tag = False
for key, val in DEFAULT_CONFIG_VERSION_2.items():
if key not in config:
config[key] = val
_tag = True
if _tag:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
f.flush()
@staticmethod
def get(key, default=None):
self.load_from_dict(config)
def get(self, key: str, default=None):
'''
从文件系统中直接获取配置
'''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
if key in d:
return d[key]
@@ -38,30 +247,29 @@ class CmdConfig():
'''
从文件系统中获取所有配置
'''
check_exist()
with open(cpath, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
if not conf_str:
return {}
conf = json.loads(conf_str)
return conf
def put(self, key, value):
with open(cpath, "r", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
d[key] = value
with open(cpath, "w", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=2, ensure_ascii=False)
f.flush()
self.cached_config[key] = value
@staticmethod
def put_by_dot_str(key: str, value):
'''
根据点分割的字符串,将值写入配置文件
'''
with open(cpath, "r", encoding="utf-8-sig") as f:
def to_dict(self) -> Dict:
return asdict(self)
def put_by_dot_str(self, key: str, value):
'''根据点分割的字符串,将值写入配置文件'''
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
d = json.load(f)
_d = d
_ks = key.split(".")
@@ -70,23 +278,19 @@ class CmdConfig():
_d[_ks[i]] = value
else:
_d = _d[_ks[i]]
with open(cpath, "w", encoding="utf-8-sig") as f:
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(d, f, indent=2, ensure_ascii=False)
f.flush()
def init_config_items(self, d: dict):
conf = self.get_all()
def update_by_path(self, path: List):
'''根据路径更新配置文件。
if not self.cached_config:
self.cached_config = conf
这个方法首先会更新缓存在内存中的配置,然后再写入文件。
'''
for key in path:
if key not in self:
raise KeyError(f"Key {key} not found in config.")
_tag = False
for key, val in d.items():
if key not in conf:
conf[key] = val
_tag = True
if _tag:
with open(cpath, "w", encoding="utf-8-sig") as f:
json.dump(conf, f, indent=2, ensure_ascii=False)
f.flush()
def check_exist(self) -> bool:
return os.path.exists(ASTRBOT_CONFIG_PATH)
+9 -10
View File
@@ -1,16 +1,15 @@
import json, os
from util.cmd_config import CmdConfig
import json, os, shutil
import logging
logger = logging.getLogger("astrbot")
def try_migrate_config():
'''
将 cmd_config.json 迁移至 data/cmd_config.json (如果存在的话)
'''
if os.path.exists("cmd_config.json"):
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
data = json.load(f)
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
if os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"):
try:
os.remove("cmd_config.json")
except Exception as e:
pass
shutil.move("cmd_config.json", "data/cmd_config.json")
except:
logger.error("迁移 cmd_config.json 失败。AstrBot 将不会读取配置文件,你可以手动将 cmd_config.json 迁移至 data/cmd_config.json。")
+2 -1
View File
@@ -5,6 +5,7 @@ import sys
from type.types import Context
from collections import defaultdict
from type.config import VERSION
class MetricUploader():
def __init__(self, context: Context) -> None:
@@ -49,7 +50,7 @@ class MetricUploader():
try:
res = {
"stat_version": "moon",
"version": context.version, # 版本号
"version": VERSION, # 版本号
"platform_stats": self.platform_stats, # 过去 30 分钟各消息平台交互消息数
"llm_stats": self.llm_stats,
"plugin_stats": self.plugin_stats,
+1 -1
View File
@@ -7,6 +7,6 @@ Platform类是消息平台的抽象类,定义了消息平台的基本接口。
from model.platform import Platform
from model.platform.qq_nakuru import QQGOCQ
from model.platform.qq_nakuru import QQNakuru
from model.platform.qq_official import QQOfficial
from model.platform.qq_aiocqhttp import AIOCQHTTP