Merge pull request #184 from Soulter/config-refactor
更易读的配置格式和平台、LLM多实例
This commit is contained in:
@@ -26,7 +26,11 @@ jobs:
|
||||
mkdir temp
|
||||
|
||||
- name: Run tests
|
||||
run: PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
run: |
|
||||
export LLM_MODEL=${{ secrets.LLM_MODEL }}
|
||||
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
|
||||
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
|
||||
+2
-1
@@ -10,4 +10,5 @@ cmd_config.json
|
||||
data/*
|
||||
cookies.json
|
||||
logs/
|
||||
addons/plugins
|
||||
addons/plugins
|
||||
.coverage
|
||||
+24
-29
@@ -13,7 +13,7 @@ from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from util.cmd_config import CmdConfig
|
||||
from util.cmd_config import AstrBotConfig
|
||||
from util.metrics import MetricUploader
|
||||
from util.config_utils import *
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
@@ -24,29 +24,15 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
class AstrBotBootstrap():
|
||||
def __init__(self) -> None:
|
||||
self.context = Context()
|
||||
self.config_helper = CmdConfig()
|
||||
|
||||
# load configs and ensure the backward compatibility
|
||||
try_migrate_config()
|
||||
self.config_helper = AstrBotConfig()
|
||||
self.context.config_helper = self.config_helper
|
||||
self.context.base_config = self.config_helper.cached_config
|
||||
|
||||
self.context.default_personality = {
|
||||
"name": "default",
|
||||
"prompt": self.context.base_config.get("default_personality_str", ""),
|
||||
}
|
||||
self.context.unique_session = self.context.base_config.get("uniqueSessionMode", False)
|
||||
nick_qq = self.context.base_config.get("nick_qq", ('/', '!'))
|
||||
if isinstance(nick_qq, str): nick_qq = (nick_qq, )
|
||||
self.context.nick = nick_qq
|
||||
self.context.t2i_mode = self.context.base_config.get("qq_pic_mode", True)
|
||||
self.context.version = VERSION
|
||||
|
||||
logger.info("AstrBot v" + self.context.version)
|
||||
|
||||
logger.info("AstrBot v" + VERSION)
|
||||
# apply proxy settings
|
||||
http_proxy = self.context.base_config.get("http_proxy")
|
||||
https_proxy = self.context.base_config.get("https_proxy")
|
||||
http_proxy = self.context.config_helper.http_proxy
|
||||
https_proxy = self.context.config_helper.https_proxy
|
||||
if http_proxy:
|
||||
os.environ['HTTP_PROXY'] = http_proxy
|
||||
if https_proxy:
|
||||
@@ -68,10 +54,9 @@ class AstrBotBootstrap():
|
||||
self.db_conn_helper = dbConn()
|
||||
|
||||
# load llm provider
|
||||
self.llm_instance: Provider = None
|
||||
self.load_llm()
|
||||
|
||||
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper, self.llm_instance)
|
||||
self.message_handler = MessageHandler(self.context, self.command_manager, self.db_conn_helper)
|
||||
self.platfrom_manager = PlatformManager(self.context, self.message_handler)
|
||||
self.dashboard = AstrBotDashBoard(self.context, plugin_manager=self.plugin_manager, astrbot_updator=self.updator)
|
||||
self.metrics_uploader = MetricUploader(self.context)
|
||||
@@ -114,16 +99,26 @@ class AstrBotBootstrap():
|
||||
return
|
||||
|
||||
def load_llm(self):
|
||||
if 'openai' in self.config_helper.cached_config and \
|
||||
len(self.config_helper.cached_config['openai']['key']) and \
|
||||
self.config_helper.cached_config['openai']['key'][0] is not None:
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
f = False
|
||||
llms = self.context.config_helper.llm
|
||||
logger.info(f"加载 {len(llms)} 个 LLM Provider...")
|
||||
for llm in llms:
|
||||
if llm.enable:
|
||||
if llm.name == "openai" and llm.key and llm.enable:
|
||||
self.load_openai(llm)
|
||||
f = True
|
||||
logger.info(f"已启用 OpenAI API 支持。")
|
||||
else:
|
||||
logger.warn(f"未知的 LLM Provider: {llm.name}")
|
||||
if f:
|
||||
from model.command.openai_official_handler import OpenAIOfficialCommandHandler
|
||||
self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager)
|
||||
self.llm_instance = ProviderOpenAIOfficial(self.context)
|
||||
self.openai_command_handler.set_provider(self.llm_instance)
|
||||
self.context.register_provider("internal_openai", self.llm_instance)
|
||||
logger.info("已启用 OpenAI API 支持。")
|
||||
self.openai_command_handler.set_provider(self.context.llms[0].llm_instance)
|
||||
|
||||
def load_openai(self, llm_config):
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
inst = ProviderOpenAIOfficial(llm_config)
|
||||
self.context.register_provider("internal_openai", inst)
|
||||
|
||||
def load_plugins(self):
|
||||
self.plugin_manager.plugin_reload()
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
from aip import AipContentCensor
|
||||
from util.cmd_config import BaiduAIPConfig
|
||||
|
||||
|
||||
class BaiduJudge:
|
||||
def __init__(self, baidu_configs) -> None:
|
||||
if 'app_id' in baidu_configs and 'api_key' in baidu_configs and 'secret_key' in baidu_configs:
|
||||
self.app_id = str(baidu_configs['app_id'])
|
||||
self.api_key = baidu_configs['api_key']
|
||||
self.secret_key = baidu_configs['secret_key']
|
||||
self.client = AipContentCensor(
|
||||
self.app_id, self.api_key, self.secret_key)
|
||||
else:
|
||||
raise ValueError("Baidu configs error! 请填写百度内容审核服务相关配置!")
|
||||
def __init__(self, baidu_configs: BaiduAIPConfig) -> None:
|
||||
self.app_id = baidu_configs.app_id
|
||||
self.api_key = baidu_configs.api_key
|
||||
self.secret_key = baidu_configs.secret_key
|
||||
self.client = AipContentCensor(self.app_id,
|
||||
self.api_key,
|
||||
self.secret_key)
|
||||
|
||||
def judge(self, text):
|
||||
res = self.client.textCensorUserDefined(text)
|
||||
|
||||
+13
-28
@@ -25,16 +25,11 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
class RateLimitHelper():
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.user_rate_limit: Dict[int, int] = {}
|
||||
self.rate_limit_time: int = 60
|
||||
self.rate_limit_count: int = 10
|
||||
rl = context.config_helper.platform_settings.rate_limit
|
||||
self.rate_limit_time: int = rl.time
|
||||
self.rate_limit_count: int = rl.count
|
||||
self.user_frequency = {}
|
||||
|
||||
if 'limit' in context.base_config:
|
||||
if 'count' in context.base_config['limit']:
|
||||
self.rate_limit_count = context.base_config['limit']['count']
|
||||
if 'time' in context.base_config['limit']:
|
||||
self.rate_limit_time = context.base_config['limit']['time']
|
||||
|
||||
|
||||
def check_frequency(self, session_id: str) -> bool:
|
||||
'''
|
||||
检查发言频率
|
||||
@@ -59,12 +54,11 @@ class RateLimitHelper():
|
||||
class ContentSafetyHelper():
|
||||
def __init__(self, context: Context) -> None:
|
||||
self.baidu_judge = None
|
||||
if 'baidu_api' in context.base_config and \
|
||||
'enable' in context.base_config['baidu_aip'] and \
|
||||
context.base_config['baidu_aip']['enable']:
|
||||
aip = context.config_helper.content_safety.baidu_aip
|
||||
if aip.enable:
|
||||
try:
|
||||
from astrbot.message.baidu_aip_judge import BaiduJudge
|
||||
self.baidu_judge = BaiduJudge(context.base_config['baidu_aip'])
|
||||
self.baidu_judge = BaiduJudge(aip)
|
||||
logger.info("已启用百度 AI 内容审核。")
|
||||
except BaseException as e:
|
||||
logger.error("百度 AI 内容审核初始化失败。")
|
||||
@@ -107,22 +101,19 @@ class ContentSafetyHelper():
|
||||
class MessageHandler():
|
||||
def __init__(self, context: Context,
|
||||
command_manager: CommandManager,
|
||||
persist_manager: dbConn,
|
||||
provider: Provider) -> None:
|
||||
persist_manager: dbConn) -> None:
|
||||
self.context = context
|
||||
self.command_manager = command_manager
|
||||
self.persist_manager = persist_manager
|
||||
self.rate_limit_helper = RateLimitHelper(context)
|
||||
self.content_safety_helper = ContentSafetyHelper(context)
|
||||
self.llm_wake_prefix = self.context.base_config['llm_wake_prefix']
|
||||
self.llm_wake_prefix = self.context.config_helper.llm_settings.wake_prefix
|
||||
if self.llm_wake_prefix:
|
||||
self.llm_wake_prefix = self.llm_wake_prefix.strip()
|
||||
self.nicks = self.context.nick
|
||||
self.provider = provider
|
||||
self.reply_prefix = str(self.context.reply_prefix)
|
||||
|
||||
self.provider = self.context.llms[0].llm_instance if len(self.context.llms) > 0 else None
|
||||
self.reply_prefix = str(self.context.config_helper.platform_settings.reply_prefix)
|
||||
self.llm_tools = FuncCall(self.provider)
|
||||
|
||||
|
||||
def set_provider(self, provider: Provider):
|
||||
self.provider = provider
|
||||
|
||||
@@ -148,7 +139,7 @@ class MessageHandler():
|
||||
return
|
||||
|
||||
# remove the nick prefix
|
||||
for nick in self.nicks:
|
||||
for nick in self.context.config_helper.wake_prefix:
|
||||
if msg_plain.startswith(nick):
|
||||
msg_plain = msg_plain.removeprefix(nick)
|
||||
break
|
||||
@@ -187,12 +178,6 @@ class MessageHandler():
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
break
|
||||
|
||||
# web_search = self.context.web_search
|
||||
# if not web_search and msg_plain.startswith("ws"):
|
||||
# # leverage web search feature
|
||||
# web_search = True
|
||||
# msg_plain = msg_plain.removeprefix("ws").strip()
|
||||
try:
|
||||
if not self.llm_tools.empty():
|
||||
# tools-use
|
||||
|
||||
@@ -2,7 +2,6 @@ from dataclasses import dataclass
|
||||
|
||||
class DashBoardData():
|
||||
stats: dict = {}
|
||||
configs: dict = {}
|
||||
|
||||
@dataclass
|
||||
class Response():
|
||||
|
||||
+57
-521
@@ -1,537 +1,73 @@
|
||||
import threading
|
||||
import asyncio
|
||||
|
||||
from . import DashBoardData
|
||||
from typing import Union, Optional
|
||||
from util.cmd_config import CmdConfig
|
||||
from dataclasses import dataclass
|
||||
from util.cmd_config import AstrBotConfig
|
||||
from dataclasses import dataclass, asdict
|
||||
from util.plugin_dev.api.v1.config import update_config
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from type.types import Context
|
||||
from type.config import CONFIG_METADATA_2
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
|
||||
@dataclass
|
||||
class DashBoardConfig():
|
||||
config_type: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
path: Optional[str] = None # 仅 item 才需要
|
||||
body: Optional[list['DashBoardConfig']] = None # 仅 group 才需要
|
||||
value: Optional[Union[list, dict, str, int, bool]] = None # 仅 item 才需要
|
||||
val_type: Optional[str] = None # 仅 item 才需要
|
||||
|
||||
|
||||
class DashBoardHelper():
|
||||
def __init__(self, context: Context, dashboard_data: DashBoardData):
|
||||
dashboard_data.configs = {
|
||||
"data": []
|
||||
}
|
||||
def __init__(self, context: Context):
|
||||
self.context = context
|
||||
self.parse_default_config(dashboard_data, context.base_config)
|
||||
self.config_key_dont_show = ['dashboard', 'config_version']
|
||||
|
||||
# 将 config.yaml、 中的配置解析到 dashboard_data.configs 中
|
||||
def parse_default_config(self, dashboard_data: DashBoardData, config: dict):
|
||||
|
||||
try:
|
||||
qq_official_platform_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="QQ(官方)",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启用 QQ_OFFICIAL 平台",
|
||||
description="官方的接口,仅支持 QQ 频道。详见 q.qq.com",
|
||||
value=config['qqbot']['enable'],
|
||||
path="qqbot.enable",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="QQ机器人APPID",
|
||||
description="详见 q.qq.com",
|
||||
value=config['qqbot']['appid'],
|
||||
path="qqbot.appid",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="QQ机器人令牌",
|
||||
description="详见 q.qq.com",
|
||||
value=config['qqbot']['token'],
|
||||
path="qqbot.token",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="QQ机器人 Secret",
|
||||
description="详见 q.qq.com",
|
||||
value=config['qqbot_secret'],
|
||||
path="qqbot_secret",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否允许 QQ 频道私聊",
|
||||
description="如果启用,机器人会响应私聊消息。",
|
||||
value=config['direct_message_mode'],
|
||||
path="direct_message_mode",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否接收QQ群消息",
|
||||
description="需要机器人有相应的群消息接收权限。在 q.qq.com 上查看。",
|
||||
value=config['qqofficial_enable_group_message'],
|
||||
path="qqofficial_enable_group_message",
|
||||
),
|
||||
]
|
||||
)
|
||||
qq_gocq_platform_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="QQ(nakuru)",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启用",
|
||||
description="",
|
||||
value=config['gocqbot']['enable'],
|
||||
path="gocqbot.enable",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="HTTP 服务器地址",
|
||||
description="",
|
||||
value=config['gocq_host'],
|
||||
path="gocq_host",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="HTTP 服务器端口",
|
||||
description="",
|
||||
value=config['gocq_http_port'],
|
||||
path="gocq_http_port",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="WebSocket 服务器端口",
|
||||
description="目前仅支持正向 WebSocket",
|
||||
value=config['gocq_websocket_port'],
|
||||
path="gocq_websocket_port",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应群消息",
|
||||
description="",
|
||||
value=config['gocq_react_group'],
|
||||
path="gocq_react_group",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应私聊消息",
|
||||
description="",
|
||||
value=config['gocq_react_friend'],
|
||||
path="gocq_react_friend",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应群成员增加消息",
|
||||
description="",
|
||||
value=config['gocq_react_group_increase'],
|
||||
path="gocq_react_group_increase",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="是否响应频道消息",
|
||||
description="",
|
||||
value=config['gocq_react_guild'],
|
||||
path="gocq_react_guild",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="转发阈值(字符数)",
|
||||
description="机器人回复的消息长度超出这个值后,会被折叠成转发卡片发出以减少刷屏。",
|
||||
value=config['qq_forward_threshold'],
|
||||
path="qq_forward_threshold",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
qq_aiocqhttp_platform_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="QQ(aiocqhttp)",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启用",
|
||||
description="",
|
||||
value=config['aiocqhttp']['enable'],
|
||||
path="aiocqhttp.enable",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="WebSocket 反向连接 host",
|
||||
description="",
|
||||
value=config['aiocqhttp']['ws_reverse_host'],
|
||||
path="aiocqhttp.ws_reverse_host",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="WebSocket 反向连接 port",
|
||||
description="",
|
||||
value=config['aiocqhttp']['ws_reverse_port'],
|
||||
path="aiocqhttp.ws_reverse_port",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
general_platform_detail_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="通用平台配置",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启动消息文字转图片",
|
||||
description="启动后,机器人会将消息转换为图片发送,以降低风控风险。",
|
||||
value=config['qq_pic_mode'],
|
||||
path="qq_pic_mode",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="消息限制时间",
|
||||
description="在此时间内,机器人不会回复同一个用户的消息。单位:秒",
|
||||
value=config['limit']['time'],
|
||||
path="limit.time",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="消息限制次数",
|
||||
description="在上面的时间内,如果用户发送消息超过此次数,则机器人不会回复。单位:次",
|
||||
value=config['limit']['count'],
|
||||
path="limit.count",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="回复前缀",
|
||||
description="[xxxx] 你好! 其中xxxx是你可以填写的前缀。如果为空则不显示。",
|
||||
value=config['reply_prefix'],
|
||||
path="reply_prefix",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="list",
|
||||
name="通用管理员用户 ID(支持多个管理员)。通过 !myid 指令获取。",
|
||||
description="",
|
||||
value=config['other_admins'],
|
||||
path="other_admins",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="独立会话",
|
||||
description="是否启用独立会话模式,即 1 个用户自然账号 1 个会话。",
|
||||
value=config['uniqueSessionMode'],
|
||||
path="uniqueSessionMode",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="LLM 唤醒词",
|
||||
description="如果不为空, 那么只有当消息以此词开头时,才会调用大语言模型进行回复。如设置为 /chat,那么只有当消息以 /chat 开头时,才会调用大语言模型进行回复。",
|
||||
value=config['llm_wake_prefix'],
|
||||
path="llm_wake_prefix",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
openai_official_llm_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="OpenAI 官方接口类设置",
|
||||
description="",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="list",
|
||||
name="OpenAI API Key",
|
||||
description="OpenAI API 的 Key。支持使用非官方但兼容的 API(第三方中转key)。",
|
||||
value=config['openai']['key'],
|
||||
path="openai.key",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI API 节点地址(api base)",
|
||||
description="OpenAI API 的节点地址,配合非官方 API 使用。如果不想填写,那么请填写 none",
|
||||
value=config['openai']['api_base'],
|
||||
path="openai.api_base",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI model",
|
||||
description="OpenAI LLM 模型。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['model'],
|
||||
path="openai.chatGPTConfigs.model",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="OpenAI max_tokens",
|
||||
description="OpenAI 最大生成长度。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['max_tokens'],
|
||||
path="openai.chatGPTConfigs.max_tokens",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI temperature",
|
||||
description="OpenAI 温度。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['temperature'],
|
||||
path="openai.chatGPTConfigs.temperature",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI top_p",
|
||||
description="OpenAI top_p。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['top_p'],
|
||||
path="openai.chatGPTConfigs.top_p",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI frequency_penalty",
|
||||
description="OpenAI frequency_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['frequency_penalty'],
|
||||
path="openai.chatGPTConfigs.frequency_penalty",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="float",
|
||||
name="OpenAI presence_penalty",
|
||||
description="OpenAI presence_penalty。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['chatGPTConfigs']['presence_penalty'],
|
||||
path="openai.chatGPTConfigs.presence_penalty",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="int",
|
||||
name="OpenAI 总生成长度限制",
|
||||
description="OpenAI 总生成长度限制。详见 https://platform.openai.com/docs/api-reference/chat",
|
||||
value=config['openai']['total_tokens_limit'],
|
||||
path="openai.total_tokens_limit",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成模型",
|
||||
description="OpenAI 图像生成模型。",
|
||||
value=config['openai_image_generate']['model'],
|
||||
path="openai_image_generate.model",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成大小",
|
||||
description="OpenAI 图像生成大小。",
|
||||
value=config['openai_image_generate']['size'],
|
||||
path="openai_image_generate.size",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成风格",
|
||||
description="OpenAI 图像生成风格。修改前请参考 OpenAI 官方文档",
|
||||
value=config['openai_image_generate']['style'],
|
||||
path="openai_image_generate.style",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="OpenAI 图像生成质量",
|
||||
description="OpenAI 图像生成质量。修改前请参考 OpenAI 官方文档",
|
||||
value=config['openai_image_generate']['quality'],
|
||||
path="openai_image_generate.quality",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="问题题首提示词",
|
||||
description="如果填写了此项,在每个对大语言模型的请求中,都会在问题前加上此提示词。",
|
||||
value=config['llm_env_prompt'],
|
||||
path="llm_env_prompt",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="默认人格文本",
|
||||
description="默认人格文本",
|
||||
value=config['default_personality_str'],
|
||||
path="default_personality_str",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
baidu_aip_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="百度内容审核",
|
||||
description="需要去申请",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="bool",
|
||||
name="启动百度内容审核服务",
|
||||
description="",
|
||||
value=config['baidu_aip']['enable'],
|
||||
path="baidu_aip.enable"
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="APP ID",
|
||||
description="",
|
||||
value=config['baidu_aip']['app_id'],
|
||||
path="baidu_aip.app_id"
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="API KEY",
|
||||
description="",
|
||||
value=config['baidu_aip']['api_key'],
|
||||
path="baidu_aip.api_key"
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="SECRET KEY",
|
||||
description="",
|
||||
value=config['baidu_aip']['secret_key'],
|
||||
path="baidu_aip.secret_key"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
other_group = DashBoardConfig(
|
||||
config_type="group",
|
||||
name="其他配置",
|
||||
description="其他配置描述",
|
||||
body=[
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="HTTP 代理地址",
|
||||
description="建议上下一致",
|
||||
value=config['http_proxy'],
|
||||
path="http_proxy",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="HTTPS 代理地址",
|
||||
description="建议上下一致",
|
||||
value=config['https_proxy'],
|
||||
path="https_proxy",
|
||||
),
|
||||
DashBoardConfig(
|
||||
config_type="item",
|
||||
val_type="str",
|
||||
name="面板用户名",
|
||||
description="是的,就是你理解的这个面板的用户名",
|
||||
value=config['dashboard_username'],
|
||||
path="dashboard_username",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
dashboard_data.configs['data'] = [
|
||||
qq_official_platform_group,
|
||||
qq_gocq_platform_group,
|
||||
general_platform_detail_group,
|
||||
openai_official_llm_group,
|
||||
other_group,
|
||||
baidu_aip_group,
|
||||
qq_aiocqhttp_platform_group
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"配置文件解析错误:{e}")
|
||||
raise e
|
||||
|
||||
def save_config(self, post_config: list, namespace: str):
|
||||
'''
|
||||
根据 path 解析并保存配置
|
||||
'''
|
||||
|
||||
queue = post_config
|
||||
while len(queue) > 0:
|
||||
config = queue.pop(0)
|
||||
if config['config_type'] == "group":
|
||||
for item in config['body']:
|
||||
queue.append(item)
|
||||
elif config['config_type'] == "item":
|
||||
if config['path'] is None or config['path'] == "":
|
||||
def validate_config(self, data):
|
||||
errors = []
|
||||
# 递归验证数据
|
||||
def validate(data, path=""):
|
||||
for key, meta in CONFIG_METADATA_2.items():
|
||||
if key not in data:
|
||||
if key not in self.config_key_dont_show:
|
||||
# 这些key不会传给前端,所以不需要验证
|
||||
errors.append(f"Missing key: {path}{key}")
|
||||
continue
|
||||
value = data[key]
|
||||
if meta["type"] == "int" and not isinstance(value, int):
|
||||
errors.append(f"Invalid type for {path}{key}: expected int, got {type(value).__name__}")
|
||||
elif meta["type"] == "bool" and not isinstance(value, bool):
|
||||
errors.append(f"Invalid type for {path}{key}: expected bool, got {type(value).__name__}")
|
||||
elif meta["type"] == "string" and not isinstance(value, str):
|
||||
errors.append(f"Invalid type for {path}{key}: expected string, got {type(value).__name__}")
|
||||
elif meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(f"Invalid type for {path}{key}: expected list, got {type(value).__name__}")
|
||||
for item in value:
|
||||
validate(item, meta["items"], path=f"{path}{key}.")
|
||||
elif meta["type"] == "dict" and not isinstance(value, dict):
|
||||
errors.append(f"Invalid type for {path}{key}: expected dict, got {type(value).__name__}")
|
||||
validate(value, meta["items"], path=f"{path}{key}.")
|
||||
validate(data)
|
||||
|
||||
# hardcode warning
|
||||
data['config_version'] = self.context.config_helper.config_version
|
||||
data['dashboard'] = asdict(self.context.config_helper.dashboard)
|
||||
|
||||
return errors
|
||||
|
||||
path = config['path'].split('.')
|
||||
if len(path) == 0:
|
||||
continue
|
||||
def save_astrbot_config(self, post_config: dict):
|
||||
'''验证并保存配置'''
|
||||
errors = self.validate_config(post_config)
|
||||
if errors:
|
||||
raise ValueError(f"格式校验未通过: {errors}")
|
||||
self.context.config_helper.flush_config(post_config)
|
||||
|
||||
def save_extension_config(self, post_config: dict):
|
||||
if 'namespace' not in post_config:
|
||||
raise ValueError("Missing key: namespace")
|
||||
if 'config' not in post_config:
|
||||
raise ValueError("Missing key: config")
|
||||
|
||||
if config['val_type'] == "bool":
|
||||
self._write_config(
|
||||
namespace, config['path'], config['value'])
|
||||
elif config['val_type'] == "str":
|
||||
self._write_config(
|
||||
namespace, config['path'], config['value'])
|
||||
elif config['val_type'] == "int":
|
||||
try:
|
||||
self._write_config(
|
||||
namespace, config['path'], int(config['value']))
|
||||
except:
|
||||
raise ValueError(f"配置项 {config['name']} 的值必须是整数")
|
||||
elif config['val_type'] == "float":
|
||||
try:
|
||||
self._write_config(
|
||||
namespace, config['path'], float(config['value']))
|
||||
except:
|
||||
raise ValueError(f"配置项 {config['name']} 的值必须是浮点数")
|
||||
elif config['val_type'] == "list":
|
||||
if config['value'] is None:
|
||||
self._write_config(namespace, config['path'], [])
|
||||
elif not isinstance(config['value'], list):
|
||||
raise ValueError(f"配置项 {config['name']} 的值必须是列表")
|
||||
self._write_config(
|
||||
namespace, config['path'], config['value'])
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"未知或者未实现的配置项类型:{config['val_type']}")
|
||||
|
||||
def _write_config(self, namespace: str, key: str, value):
|
||||
if namespace == "" or namespace.startswith("internal_"):
|
||||
# 机器人自带配置,存到 config.yaml
|
||||
self.context.config_helper.put_by_dot_str(key, value)
|
||||
else:
|
||||
namespace = post_config['namespace']
|
||||
config: list = post_config['config'][0]['body']
|
||||
for item in config:
|
||||
key = item['path']
|
||||
value = item['value']
|
||||
typ = item['val_type']
|
||||
if typ == 'int':
|
||||
if not value.isdigit():
|
||||
raise ValueError(f"Invalid type for {namespace}.{key}: expected int, got {type(value).__name__}")
|
||||
value = int(value)
|
||||
update_config(namespace, key, value)
|
||||
|
||||
+75
-126
@@ -19,6 +19,7 @@ from dashboard.helper import DashBoardHelper
|
||||
from util.io import get_local_ip_addresses
|
||||
from model.plugin.manager import PluginManager
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from type.config import CONFIG_METADATA_2
|
||||
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -30,9 +31,10 @@ class AstrBotDashBoard():
|
||||
self.plugin_manager = plugin_manager
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.dashboard_data = DashBoardData()
|
||||
self.dashboard_helper = DashBoardHelper(self.context, self.dashboard_data)
|
||||
self.dashboard_helper = DashBoardHelper(self.context)
|
||||
|
||||
self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/")
|
||||
self.dashboard_be.json.sort_keys=False # 不按照字典排序
|
||||
logging.getLogger('werkzeug').setLevel(logging.ERROR)
|
||||
self.dashboard_be.logger.setLevel(logging.ERROR)
|
||||
|
||||
@@ -68,8 +70,8 @@ class AstrBotDashBoard():
|
||||
|
||||
@self.dashboard_be.post("/api/authenticate")
|
||||
def authenticate():
|
||||
username = self.context.base_config.get("dashboard_username", "")
|
||||
password = self.context.base_config.get("dashboard_password", "")
|
||||
username = self.context.config_helper.dashboard.username
|
||||
password = self.context.config_helper.dashboard.password
|
||||
# 获得请求体
|
||||
post_data = request.json
|
||||
if post_data["username"] == username and post_data["password"] == password:
|
||||
@@ -90,7 +92,7 @@ class AstrBotDashBoard():
|
||||
|
||||
@self.dashboard_be.post("/api/change_password")
|
||||
def change_password():
|
||||
password = self.context.base_config.get("dashboard_password", "")
|
||||
password = self.context.config_helper.dashboard.password
|
||||
# 获得请求体
|
||||
post_data = request.json
|
||||
if post_data["password"] == password:
|
||||
@@ -130,40 +132,53 @@ class AstrBotDashBoard():
|
||||
|
||||
@self.dashboard_be.get("/api/configs")
|
||||
def get_configs():
|
||||
# 如果params中有namespace,则返回该namespace下的配置
|
||||
# 否则返回所有配置
|
||||
# namespace 为空时返回 AstrBot 配置
|
||||
# 否则返回指定 namespace 的插件配置
|
||||
namespace = "" if "namespace" not in request.args else request.args["namespace"]
|
||||
conf = self._get_configs(namespace)
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=conf
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/config_outline")
|
||||
def get_config_outline():
|
||||
outline = self._generate_outline()
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=outline
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/configs")
|
||||
def post_configs():
|
||||
post_configs = request.json
|
||||
try:
|
||||
self.on_post_configs(post_configs)
|
||||
if not namespace:
|
||||
return Response(
|
||||
status="success",
|
||||
message="保存成功~ 机器人将在 2 秒内重启以应用新的配置。",
|
||||
message="",
|
||||
data=self._get_astrbot_config()
|
||||
).__dict__
|
||||
return Response(
|
||||
status="success",
|
||||
message="",
|
||||
data=self._get_extension_config(namespace)
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/astrbot-configs")
|
||||
def post_astrbot_configs():
|
||||
post_configs = request.json
|
||||
try:
|
||||
self.save_astrbot_configs(post_configs)
|
||||
return Response(
|
||||
status="success",
|
||||
message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=self.dashboard_data.configs
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.post("/api/extension-configs")
|
||||
def post_extension_configs():
|
||||
post_configs = request.json
|
||||
try:
|
||||
self.save_extension_configs(post_configs)
|
||||
return Response(
|
||||
status="success",
|
||||
message="保存成功~ 机器人将在 3 秒内重启以应用新的配置。",
|
||||
data=None
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
return Response(
|
||||
status="error",
|
||||
message=e.__str__(),
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
@self.dashboard_be.get("/api/extensions")
|
||||
@@ -365,107 +380,41 @@ class AstrBotDashBoard():
|
||||
data=None
|
||||
).__dict__
|
||||
|
||||
def on_post_configs(self, post_configs: dict):
|
||||
def save_astrbot_configs(self, post_configs: dict):
|
||||
try:
|
||||
if 'base_config' in post_configs:
|
||||
self.dashboard_helper.save_config(
|
||||
post_configs['base_config'], namespace='') # 基础配置
|
||||
self.dashboard_helper.save_config(
|
||||
post_configs['config'], namespace=post_configs['namespace']) # 选定配置
|
||||
self.dashboard_helper.parse_default_config(
|
||||
self.dashboard_data, self.context.config_helper.get_all())
|
||||
# 重启
|
||||
threading.Thread(target=self.astrbot_updator._reboot,
|
||||
args=(2, self.context), daemon=True).start()
|
||||
self.dashboard_helper.save_astrbot_config(post_configs)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def save_extension_configs(self, post_configs: dict):
|
||||
try:
|
||||
self.dashboard_helper.save_extension_config(post_configs)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(3, self.context), daemon=True).start()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _get_astrbot_config(self):
|
||||
config = self.context.config_helper.to_dict()
|
||||
for key in self.dashboard_helper.config_key_dont_show:
|
||||
if key in config:
|
||||
del config[key]
|
||||
return {
|
||||
"metadata": CONFIG_METADATA_2,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
def _get_configs(self, namespace: str):
|
||||
if namespace == "":
|
||||
ret = [self.dashboard_data.configs['data'][4],
|
||||
self.dashboard_data.configs['data'][5],]
|
||||
elif namespace == "internal_platform_qq_official":
|
||||
ret = [self.dashboard_data.configs['data'][0],]
|
||||
elif namespace == "internal_platform_qq_gocq":
|
||||
ret = [self.dashboard_data.configs['data'][1],]
|
||||
elif namespace == "internal_platform_general": # 全局平台配置
|
||||
ret = [self.dashboard_data.configs['data'][2],]
|
||||
elif namespace == "internal_llm_openai_official":
|
||||
ret = [self.dashboard_data.configs['data'][3],]
|
||||
elif namespace == "internal_platform_qq_aiocqhttp":
|
||||
ret = [self.dashboard_data.configs['data'][6],]
|
||||
else:
|
||||
path = f"data/config/{namespace}.json"
|
||||
if not os.path.exists(path):
|
||||
return []
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
ret = [{
|
||||
"config_type": "group",
|
||||
"name": namespace + " 插件配置",
|
||||
"description": "",
|
||||
"body": list(json.load(f).values())
|
||||
},]
|
||||
return ret
|
||||
|
||||
def _generate_outline(self):
|
||||
'''
|
||||
生成配置大纲。目前分为 platform(消息平台配置) 和 llm(语言模型配置) 两大类。
|
||||
插件的info函数中如果带了plugin_type字段,则会被归类到对应的大纲中。目前仅支持 platform 和 llm 两种类型。
|
||||
'''
|
||||
outline = [
|
||||
{
|
||||
"type": "platform",
|
||||
"name": "配置通用消息平台",
|
||||
"body": [
|
||||
{
|
||||
"title": "通用",
|
||||
"desc": "通用平台配置",
|
||||
"namespace": "internal_platform_general",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(官方)",
|
||||
"desc": "QQ官方API。支持频道、群、私聊(需获得群权限)",
|
||||
"namespace": "internal_platform_qq_official",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(nakuru)",
|
||||
"desc": "适用于 go-cqhttp",
|
||||
"namespace": "internal_platform_qq_gocq",
|
||||
"tag": ""
|
||||
},
|
||||
{
|
||||
"title": "QQ(aiocqhttp)",
|
||||
"desc": "适用于 Lagrange, LLBot, Shamrock 等支持反向WS的协议实现。",
|
||||
"namespace": "internal_platform_qq_aiocqhttp",
|
||||
"tag": ""
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "llm",
|
||||
"name": "配置 LLM",
|
||||
"body": [
|
||||
{
|
||||
"title": "OpenAI Official",
|
||||
"desc": "也支持使用官方接口的中转服务",
|
||||
"namespace": "internal_llm_openai_official",
|
||||
"tag": ""
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
for plugin in self.context.cached_plugins:
|
||||
for item in outline:
|
||||
if item['type'] == plugin.metadata.plugin_type:
|
||||
item['body'].append({
|
||||
"title": plugin.metadata.plugin_name,
|
||||
"desc": plugin.metadata.desc,
|
||||
"namespace": plugin.metadata.plugin_name,
|
||||
"tag": plugin.metadata.plugin_name
|
||||
})
|
||||
return outline
|
||||
def _get_extension_config(self, namespace: str):
|
||||
path = f"data/config/{namespace}.json"
|
||||
if not os.path.exists(path):
|
||||
return []
|
||||
with open(path, "r", encoding="utf-8-sig") as f:
|
||||
return [{
|
||||
"config_type": "group",
|
||||
"name": namespace + " 插件配置",
|
||||
"description": "",
|
||||
"body": list(json.load(f).values())
|
||||
},]
|
||||
|
||||
async def get_log_history(self):
|
||||
try:
|
||||
|
||||
@@ -8,7 +8,6 @@ from type.types import Context
|
||||
from type.config import VERSION
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from nakuru.entities.components import Image
|
||||
from util.agent.web_searcher import search_from_bing, fetch_website_content
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
@@ -63,12 +62,12 @@ class InternalCommandHandler:
|
||||
return CommandResult().message("你没有权限使用该指令。")
|
||||
l = message_str.split(" ")
|
||||
if len(l) == 1:
|
||||
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词有:{context.nick}")
|
||||
return CommandResult().message(f"设置机器人唤醒词。以唤醒词开头的消息会唤醒机器人处理,起到 @ 的效果。\n示例:wake 昵称。当前唤醒词是:{context.config_helper.wake_prefix[0]}")
|
||||
nick = l[1].strip()
|
||||
if not nick:
|
||||
return CommandResult().message("wake: 请指定唤醒词。")
|
||||
context.config_helper.put("nick_qq", nick)
|
||||
context.nick = tuple(nick)
|
||||
context.config_helper.wake_prefix = [nick]
|
||||
context.config_helper.save_config()
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
@@ -90,11 +89,7 @@ class InternalCommandHandler:
|
||||
ret = f"当前已经是最新版本 v{VERSION}。"
|
||||
else:
|
||||
ret = f"发现新版本 {update_info.version},更新内容如下:\n---\n{update_info.body}\n---\n- 使用 /update latest 更新到最新版本。\n- 使用 /update vX.X.X 更新到指定版本。"
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=False,
|
||||
message_chain=ret,
|
||||
)
|
||||
return CommandResult().message(ret)
|
||||
else:
|
||||
if tokens.get(1) == "latest":
|
||||
try:
|
||||
@@ -184,7 +179,7 @@ class InternalCommandHandler:
|
||||
async with session.get("https://soulter.top/channelbot/notice.json") as resp:
|
||||
notice = (await resp.json())["notice"]
|
||||
except BaseException as e:
|
||||
logger.warn("An error occurred while fetching astrbot notice. Never mind, it's not important.")
|
||||
logger.warning("An error occurred while fetching astrbot notice. Never mind, it's not important.")
|
||||
|
||||
msg = "# Help Center\n## 指令列表\n"
|
||||
for key, value in self.manager.commands_handler.items():
|
||||
@@ -209,10 +204,11 @@ class InternalCommandHandler:
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain=f"网页搜索功能当前状态: {context.web_search}",
|
||||
message_chain=f"网页搜索功能当前状态: {context.config_helper.llm_settings.web_search}",
|
||||
)
|
||||
elif l[1] == 'on':
|
||||
context.web_search = True
|
||||
context.config_helper.llm_settings.web_search = True
|
||||
context.config_helper.save_config()
|
||||
context.register_llm_tool("web_search", [{
|
||||
"type": "string",
|
||||
"name": "keyword",
|
||||
@@ -236,7 +232,8 @@ class InternalCommandHandler:
|
||||
message_chain="已开启网页搜索",
|
||||
)
|
||||
elif l[1] == 'off':
|
||||
context.web_search = False
|
||||
context.config_helper.llm_settings.web_search = False
|
||||
context.config_helper.save_config()
|
||||
context.unregister_llm_tool("web_search")
|
||||
context.unregister_llm_tool("fetch_website_content")
|
||||
|
||||
@@ -253,17 +250,17 @@ class InternalCommandHandler:
|
||||
)
|
||||
|
||||
def t2i_toggle(self, message: AstrMessageEvent, context: Context):
|
||||
p = context.t2i_mode
|
||||
p = context.config_helper.t2i
|
||||
if p:
|
||||
context.config_helper.put("qq_pic_mode", False)
|
||||
context.t2i_mode = False
|
||||
context.config_helper.t2i = False
|
||||
context.config_helper.save_config()
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
success=True,
|
||||
message_chain="已关闭文本转图片模式。",
|
||||
)
|
||||
context.config_helper.put("qq_pic_mode", True)
|
||||
context.t2i_mode = True
|
||||
context.config_helper.t2i = True
|
||||
context.config_helper.save_config()
|
||||
|
||||
return CommandResult(
|
||||
hit=True,
|
||||
|
||||
@@ -61,10 +61,11 @@ class Platform():
|
||||
return ret[:100] if len(ret) > 100 else ret
|
||||
|
||||
def check_nick(self, message_str: str) -> bool:
|
||||
if self.context.nick:
|
||||
for nick in self.context.nick:
|
||||
if nick and message_str.strip().startswith(nick):
|
||||
return True
|
||||
w = self.context.config_helper.wake_prefix
|
||||
if not w: return False
|
||||
for nick in w:
|
||||
if nick and message_str.strip().startswith(nick):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def convert_to_t2i_chain(self, message_result: list) -> list:
|
||||
|
||||
+33
-23
@@ -6,6 +6,12 @@ from type.types import Context
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import (
|
||||
PlatformConfig,
|
||||
AiocqhttpPlatformConfig,
|
||||
NakuruPlatformConfig,
|
||||
QQOfficialPlatformConfig
|
||||
)
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -13,36 +19,40 @@ logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
class PlatformManager():
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
self.context = context
|
||||
self.config = context.base_config
|
||||
self.msg_handler = message_handler
|
||||
|
||||
def load_platforms(self):
|
||||
tasks = []
|
||||
|
||||
if 'gocqbot' in self.config and self.config['gocqbot']['enable']:
|
||||
logger.info("启用 QQ(nakuru 适配器)")
|
||||
tasks.append(asyncio.create_task(self.gocq_bot(), name="nakuru-adapter"))
|
||||
|
||||
if 'aiocqhttp' in self.config and self.config['aiocqhttp']['enable']:
|
||||
logger.info("启用 QQ(aiocqhttp 适配器)")
|
||||
tasks.append(asyncio.create_task(self.aiocq_bot(), name="aiocqhttp-adapter"))
|
||||
platforms = self.context.config_helper.platform
|
||||
logger.info(f"加载 {len(platforms)} 个机器人消息平台...")
|
||||
for platform in platforms:
|
||||
if not platform.enable:
|
||||
continue
|
||||
if platform.name == "qq_official":
|
||||
assert isinstance(platform, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
|
||||
logger.info(f"加载 QQ官方 机器人消息平台 (appid: {platform.appid})")
|
||||
tasks.append(asyncio.create_task(self.qqofficial_bot(platform), name="qqofficial-adapter"))
|
||||
elif platform.name == "nakuru":
|
||||
assert isinstance(platform, NakuruPlatformConfig), "nakuru: 无法识别的配置类型。"
|
||||
logger.info(f"加载 QQ(nakuru) 机器人消息平台 ({platform.host}, {platform.websocket_port}, {platform.port})")
|
||||
tasks.append(asyncio.create_task(self.nakuru_bot(platform), name="nakuru-adapter"))
|
||||
elif platform.name == "aiocqhttp":
|
||||
assert isinstance(platform, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
|
||||
logger.info("加载 QQ(aiocqhttp) 机器人消息平台")
|
||||
tasks.append(asyncio.create_task(self.aiocq_bot(platform), name="aiocqhttp-adapter"))
|
||||
|
||||
# QQ频道
|
||||
if 'qqbot' in self.config and self.config['qqbot']['enable'] and self.config['qqbot']['appid'] != None:
|
||||
logger.info("启用 QQ(官方 API) 机器人消息平台")
|
||||
tasks.append(asyncio.create_task(self.qqchan_bot(), name="qqofficial-adapter"))
|
||||
|
||||
return tasks
|
||||
|
||||
async def gocq_bot(self):
|
||||
async def nakuru_bot(self, config: NakuruPlatformConfig):
|
||||
'''
|
||||
运行 QQ(nakuru 适配器)
|
||||
'''
|
||||
from model.platform.qq_nakuru import QQGOCQ
|
||||
from model.platform.qq_nakuru import QQNakuru
|
||||
noticed = False
|
||||
host = self.config.get("gocq_host", "127.0.0.1")
|
||||
port = self.config.get("gocq_websocket_port", 6700)
|
||||
http_port = self.config.get("gocq_http_port", 5700)
|
||||
host = config.host
|
||||
port = config.websocket_port
|
||||
http_port = config.port
|
||||
logger.info(
|
||||
f"正在检查连接...host: {host}, ws port: {port}, http port: {http_port}")
|
||||
while True:
|
||||
@@ -56,30 +66,30 @@ class PlatformManager():
|
||||
logger.info("nakuru 适配器已连接。")
|
||||
break
|
||||
try:
|
||||
qq_gocq = QQGOCQ(self.context, self.msg_handler)
|
||||
qq_gocq = QQNakuru(self.context, self.msg_handler, config)
|
||||
self.context.platforms.append(RegisteredPlatform(
|
||||
platform_name="nakuru", platform_instance=qq_gocq, origin="internal"))
|
||||
await qq_gocq.run()
|
||||
except BaseException as e:
|
||||
logger.error("启动 nakuru 适配器时出现错误: " + str(e))
|
||||
|
||||
def aiocq_bot(self):
|
||||
def aiocq_bot(self, config):
|
||||
'''
|
||||
运行 QQ(aiocqhttp 适配器)
|
||||
'''
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler)
|
||||
qq_aiocqhttp = AIOCQHTTP(self.context, self.msg_handler, config)
|
||||
self.context.platforms.append(RegisteredPlatform(
|
||||
platform_name="aiocqhttp", platform_instance=qq_aiocqhttp, origin="internal"))
|
||||
return qq_aiocqhttp.run_aiocqhttp()
|
||||
|
||||
def qqchan_bot(self):
|
||||
def qqofficial_bot(self, config):
|
||||
'''
|
||||
运行 QQ 官方机器人适配器
|
||||
'''
|
||||
try:
|
||||
from model.platform.qq_official import QQOfficial
|
||||
qqchannel_bot = QQOfficial(self.context, self.msg_handler)
|
||||
qqchannel_bot = QQOfficial(self.context, self.msg_handler, config)
|
||||
self.context.platforms.append(RegisteredPlatform(
|
||||
platform_name="qqofficial", platform_instance=qqchannel_bot, origin="internal"))
|
||||
return qqchannel_bot.run()
|
||||
|
||||
@@ -13,19 +13,26 @@ from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, AiocqhttpPlatformConfig
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
class AIOCQHTTP(Platform):
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig) -> None:
|
||||
super().__init__("aiocqhttp", context)
|
||||
assert isinstance(platform_config, AiocqhttpPlatformConfig), "aiocqhttp: 无法识别的配置类型。"
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.waiting = {}
|
||||
self.context = context
|
||||
self.unique_session = self.context.unique_session
|
||||
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
|
||||
self.host = self.context.base_config['aiocqhttp']['ws_reverse_host']
|
||||
self.port = self.context.base_config['aiocqhttp']['ws_reverse_port']
|
||||
self.config = platform_config
|
||||
self.unique_session = context.config_helper.platform_settings.unique_session
|
||||
self.announcement = context.config_helper.platform_settings.welcome_message_when_join
|
||||
self.host = platform_config.ws_reverse_host
|
||||
self.port = platform_config.ws_reverse_port
|
||||
self.admins = context.config_helper.admins_id
|
||||
|
||||
def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
|
||||
@@ -105,7 +112,7 @@ class AIOCQHTTP(Platform):
|
||||
while self.context.running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
async def pre_check(self, message: AstrBotMessage) -> bool:
|
||||
# if message chain contains Plain components or
|
||||
# At components which points to self_id, return True
|
||||
if message.type == MessageType.FRIEND_MESSAGE:
|
||||
@@ -114,7 +121,7 @@ class AIOCQHTTP(Platform):
|
||||
if isinstance(comp, At) and str(comp.qq) == message.self_id:
|
||||
return True, "at"
|
||||
# check commands which ignore prefix
|
||||
if self.context.command_manager.check_command_ignore_prefix(message.message_str):
|
||||
if await self.context.command_manager.check_command_ignore_prefix(message.message_str):
|
||||
return True, "command"
|
||||
# check nicks
|
||||
if self.check_nick(message.message_str):
|
||||
@@ -125,14 +132,13 @@ class AIOCQHTTP(Platform):
|
||||
logger.info(
|
||||
f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}")
|
||||
|
||||
ok, reason = self.pre_check(message)
|
||||
ok, reason = await self.pre_check(message)
|
||||
if not ok:
|
||||
return
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.sender.user_id)
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
if sender_id in self.admins:
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -167,6 +173,8 @@ class AIOCQHTTP(Platform):
|
||||
# 如果是等待回复的消息
|
||||
if message.session_id in self.waiting and self.waiting[message.session_id] == '':
|
||||
self.waiting[message.session_id] = message
|
||||
|
||||
return message_result
|
||||
|
||||
|
||||
async def reply_msg(self,
|
||||
@@ -182,17 +190,18 @@ class AIOCQHTTP(Platform):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
if (use_t2i or (use_t2i == None and self.context.config_helper.t2i)) and isinstance(result_message, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
await self._reply(message, rendered_images)
|
||||
return
|
||||
return rendered_images
|
||||
except BaseException as e:
|
||||
logger.warn(traceback.format_exc())
|
||||
logger.warn(f"以文本转图片的形式回复消息时发生错误: {e},将尝试默认方式。")
|
||||
|
||||
await self._reply(message, res)
|
||||
return res
|
||||
|
||||
async def _reply(self, message: Union[AstrBotMessage, Dict], message_chain: List[BaseMessageComponent]):
|
||||
await self.record_metrics()
|
||||
|
||||
+22
-16
@@ -18,6 +18,7 @@ from type.command import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, NakuruPlatformConfig
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -28,47 +29,53 @@ class FakeSource:
|
||||
self.group_id = group_id
|
||||
|
||||
|
||||
class QQGOCQ(Platform):
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
class QQNakuru(Platform):
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig) -> None:
|
||||
super().__init__("nakuru", context)
|
||||
assert isinstance(platform_config, NakuruPlatformConfig), "gocq: 无法识别的配置类型。"
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.waiting = {}
|
||||
self.context = context
|
||||
self.unique_session = self.context.unique_session
|
||||
self.announcement = self.context.base_config.get("announcement", "欢迎新人!")
|
||||
|
||||
self.unique_session = context.config_helper.platform_settings.unique_session
|
||||
self.announcement = context.config_helper.platform_settings.welcome_message_when_join
|
||||
self.config = platform_config
|
||||
self.admins = context.config_helper.admins_id
|
||||
|
||||
self.client = CQHTTP(
|
||||
host=self.context.base_config.get("gocq_host", "127.0.0.1"),
|
||||
port=self.context.base_config.get("gocq_websocket_port", 6700),
|
||||
http_port=self.context.base_config.get("gocq_http_port", 5700),
|
||||
host=self.config.host,
|
||||
port=self.config.websocket_port,
|
||||
http_port=self.config.port
|
||||
)
|
||||
gocq_app = self.client
|
||||
|
||||
@gocq_app.receiver("GroupMessage")
|
||||
async def _(app: CQHTTP, source: GroupMessage):
|
||||
if self.context.base_config.get("gocq_react_group", True):
|
||||
if self.config.enable_group:
|
||||
abm = self.convert_message(source)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@gocq_app.receiver("FriendMessage")
|
||||
async def _(app: CQHTTP, source: FriendMessage):
|
||||
if self.context.base_config.get("gocq_react_friend", True):
|
||||
if self.config.enable_direct_message:
|
||||
abm = self.convert_message(source)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@gocq_app.receiver("GroupMemberIncrease")
|
||||
async def _(app: CQHTTP, source: GroupMemberIncrease):
|
||||
if self.context.base_config.get("gocq_react_group_increase", True):
|
||||
if self.config.enable_group_increase:
|
||||
await app.sendGroupMessage(source.group_id, [
|
||||
Plain(text=self.announcement)
|
||||
])
|
||||
|
||||
@gocq_app.receiver("GuildMessage")
|
||||
async def _(app: CQHTTP, source: GuildMessage):
|
||||
if self.cc.get("gocq_react_guild", True):
|
||||
if self.config.enable_guild:
|
||||
abm = self.convert_message(source)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@@ -117,8 +124,7 @@ class QQGOCQ(Platform):
|
||||
|
||||
# 解析 role
|
||||
sender_id = str(message.raw_message.user_id)
|
||||
if sender_id == self.context.base_config.get('admin_qq', '') or \
|
||||
sender_id in self.context.base_config.get('other_admins', []):
|
||||
if sender_id in self.admins:
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -179,7 +185,7 @@ class QQGOCQ(Platform):
|
||||
res = [Plain(text=res), ]
|
||||
|
||||
# if image mode, put all Plain texts into a new picture.
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(res)
|
||||
if rendered_images:
|
||||
try:
|
||||
@@ -226,7 +232,7 @@ class QQGOCQ(Platform):
|
||||
plain_text_len += len(i.text)
|
||||
elif isinstance(i, Image):
|
||||
image_num += 1
|
||||
if plain_text_len > self.context.base_config.get('qq_forward_threshold', 200):
|
||||
if plain_text_len > self.context.config_helper.platform_settings.forward_threshold or image_num > 1:
|
||||
# 删除At
|
||||
for i in message_chain:
|
||||
if isinstance(i, At):
|
||||
|
||||
@@ -19,6 +19,7 @@ from nakuru.entities.components import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from astrbot.message.handler import MessageHandler
|
||||
from util.cmd_config import PlatformConfig, QQOfficialPlatformConfig
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -52,32 +53,37 @@ class botClient(Client):
|
||||
|
||||
class QQOfficial(Platform):
|
||||
|
||||
def __init__(self, context: Context, message_handler: MessageHandler) -> None:
|
||||
def __init__(self, context: Context,
|
||||
message_handler: MessageHandler,
|
||||
platform_config: PlatformConfig,
|
||||
test_mode = False) -> None:
|
||||
super().__init__("qqofficial", context)
|
||||
assert isinstance(platform_config, QQOfficialPlatformConfig), "qq_official: 无法识别的配置类型。"
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.waiting: dict = {}
|
||||
self.context = context
|
||||
self.config = platform_config
|
||||
self.admins = context.config_helper.admins_id
|
||||
|
||||
self.appid = context.base_config['qqbot']['appid']
|
||||
self.token = context.base_config['qqbot']['token']
|
||||
self.secret = context.base_config['qqbot_secret']
|
||||
self.unique_session = context.unique_session
|
||||
qq_group = context.base_config['qqofficial_enable_group_message']
|
||||
|
||||
self.appid = platform_config.appid
|
||||
self.secret = platform_config.secret
|
||||
self.unique_session = context.config_helper.platform_settings.unique_session
|
||||
qq_group = platform_config.enable_group_c2c
|
||||
guild_dm = platform_config.enable_guild_direct_message
|
||||
|
||||
if qq_group:
|
||||
self.intents = botpy.Intents(
|
||||
public_messages=True,
|
||||
public_guild_messages=True,
|
||||
direct_message=context.base_config['direct_message_mode']
|
||||
direct_message=guild_dm
|
||||
)
|
||||
else:
|
||||
self.intents = botpy.Intents(
|
||||
public_guild_messages=True,
|
||||
direct_message=context.base_config['direct_message_mode']
|
||||
direct_message=guild_dm
|
||||
)
|
||||
self.client = botClient(
|
||||
intents=self.intents,
|
||||
@@ -169,24 +175,10 @@ class QQOfficial(Platform):
|
||||
return abm
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
return self.client.start(
|
||||
appid=self.appid,
|
||||
secret=self.secret
|
||||
)
|
||||
except BaseException as e:
|
||||
# 早期的 qq-botpy 版本使用 token 登录。
|
||||
logger.error(traceback.format_exc())
|
||||
self.client = botClient(
|
||||
intents=self.intents,
|
||||
bot_log=False,
|
||||
timeout=20,
|
||||
)
|
||||
self.client.set_platform(self)
|
||||
return self.client.start(
|
||||
appid=self.appid,
|
||||
token=self.token
|
||||
)
|
||||
return self.client.start(
|
||||
appid=self.appid,
|
||||
secret=self.secret
|
||||
)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
assert isinstance(message.raw_message, (botpy.message.Message,
|
||||
@@ -211,8 +203,7 @@ class QQOfficial(Platform):
|
||||
|
||||
# 解析出 role
|
||||
sender_id = message.sender.user_id
|
||||
if sender_id == self.context.base_config.get('admin_qqchan', None) or \
|
||||
sender_id in self.context.base_config.get('other_admins', None):
|
||||
if sender_id in self.admins:
|
||||
role = 'admin'
|
||||
else:
|
||||
role = 'member'
|
||||
@@ -252,7 +243,7 @@ class QQOfficial(Platform):
|
||||
msg_ref = None
|
||||
rendered_images = []
|
||||
|
||||
if use_t2i or (use_t2i == None and self.context.base_config.get("qq_pic_mode", False)) and isinstance(res, list):
|
||||
if use_t2i or (use_t2i == None and self.context.config_helper.t2i) and isinstance(result_message, list):
|
||||
rendered_images = await self.convert_to_t2i_chain(result_message)
|
||||
|
||||
if isinstance(result_message, list):
|
||||
|
||||
@@ -8,18 +8,18 @@ import traceback
|
||||
import base64
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.images_response import ImagesResponse
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai._exceptions import *
|
||||
from util.io import download_image_by_url
|
||||
|
||||
from astrbot.persist.helper import dbConn
|
||||
from model.provider.provider import Provider
|
||||
from util.cmd_config import LLMConfig
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Logger
|
||||
from typing import List, Dict
|
||||
|
||||
from type.types import Context
|
||||
from dataclasses import asdict
|
||||
|
||||
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -47,22 +47,16 @@ MODELS = {
|
||||
}
|
||||
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def __init__(self, context: Context) -> None:
|
||||
def __init__(self, llm_config: LLMConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
os.makedirs("data/openai", exist_ok=True)
|
||||
|
||||
self.context = context
|
||||
self.key_data_path = "data/openai/keys.json"
|
||||
self.api_keys = []
|
||||
self.chosen_api_key = None
|
||||
self.base_url = None
|
||||
self.llm_config = llm_config
|
||||
self.keys_data = {} # 记录超额
|
||||
|
||||
cfg = context.base_config['openai']
|
||||
|
||||
if cfg['key']: self.api_keys = cfg['key']
|
||||
if cfg['api_base']: self.base_url = cfg['api_base']
|
||||
if llm_config.key: self.api_keys = llm_config.key
|
||||
if llm_config.api_base: self.base_url = llm_config.api_base
|
||||
if not self.api_keys:
|
||||
logger.warn("看起来你没有添加 OpenAI 的 API 密钥,OpenAI LLM 能力将不会启用。")
|
||||
else:
|
||||
@@ -75,18 +69,21 @@ class ProviderOpenAIOfficial(Provider):
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
self.model_configs: Dict = cfg['chatGPTConfigs']
|
||||
super().set_curr_model(self.model_configs['model'])
|
||||
self.image_generator_model_configs: Dict = context.base_config.get('openai_image_generate', None)
|
||||
super().set_curr_model(llm_config.model_config.model)
|
||||
if llm_config.image_generation_model_config:
|
||||
self.image_generator_model_configs: Dict = asdict(llm_config.image_generation_model_config)
|
||||
self.session_memory: Dict[str, List] = {} # 会话记忆
|
||||
self.session_memory_lock = threading.Lock()
|
||||
self.max_tokens = self.model_configs['max_tokens'] # 上下文窗口大小
|
||||
self.max_tokens = self.llm_config.model_config.max_tokens # 上下文窗口大小
|
||||
|
||||
logger.info("正在载入分词器 cl100k_base...")
|
||||
self.tokenizer = tiktoken.get_encoding("cl100k_base") # todo: 根据 model 切换分词器
|
||||
logger.info("分词器载入完成。")
|
||||
|
||||
self.DEFAULT_PERSONALITY = context.default_personality
|
||||
self.DEFAULT_PERSONALITY = {
|
||||
"prompt": self.llm_config.default_personality,
|
||||
"name": "default"
|
||||
}
|
||||
self.curr_personality = self.DEFAULT_PERSONALITY
|
||||
self.session_personality = {} # 记录了某个session是否已设置人格。
|
||||
# 从 SQLite DB 读取历史记录
|
||||
@@ -133,7 +130,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.session_personality = {} # 重置
|
||||
encoded_prompt = self.tokenizer.encode(default_personality['prompt'])
|
||||
tokens_num = len(encoded_prompt)
|
||||
model = self.model_configs['model']
|
||||
model = self.get_curr_model()
|
||||
if model in MODELS and tokens_num > MODELS[model] - 500:
|
||||
default_personality['prompt'] = self.tokenizer.decode(encoded_prompt[:MODELS[model] - 500])
|
||||
|
||||
@@ -172,7 +169,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for record in self.session_memory[session_id]:
|
||||
if "user" in record and record['user']:
|
||||
if not is_lvm and "content" in record['user'] and isinstance(record['user']['content'], list):
|
||||
logger.warn(f"由于当前模型 {self.model_configs['model']}不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
|
||||
logger.warn(f"由于当前模型 {self.get_curr_model()} 不支持视觉,将忽略上下文中的图片输入。如果一直弹出此警告,可以尝试 reset 指令。")
|
||||
continue
|
||||
context.append(record['user'])
|
||||
if "AI" in record and record['AI']:
|
||||
@@ -184,7 +181,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
是否是 LVM
|
||||
'''
|
||||
return self.model_configs['model'].startswith("gpt-4")
|
||||
return self.get_curr_model().startswith("gpt-4")
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
@@ -237,7 +234,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.session_memory[session_id].append(message)
|
||||
|
||||
# 根据 模型的上下文窗口 淘汰掉多余的记录
|
||||
curr_model = self.model_configs['model']
|
||||
curr_model = self.get_curr_model()
|
||||
if curr_model in MODELS:
|
||||
maxium_tokens_num = MODELS[curr_model] - 300 # 至少预留 300 给 completion
|
||||
# if message['usage_tokens'] > maxium_tokens_num:
|
||||
@@ -316,7 +313,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 1. 可以保证之后 pop 的时候不会出现问题
|
||||
# 2. 可以保证不会超过最大 token 数
|
||||
_encoded_prompt = self.tokenizer.encode(prompt)
|
||||
curr_model = self.model_configs['model']
|
||||
curr_model = self.get_curr_model()
|
||||
if curr_model in MODELS and len(_encoded_prompt) > MODELS[curr_model] - 300:
|
||||
_encoded_prompt = _encoded_prompt[:MODELS[curr_model] - 300]
|
||||
prompt = self.tokenizer.decode(_encoded_prompt)
|
||||
@@ -327,7 +324,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# 获取上下文,openai 格式
|
||||
contexts = await self.retrieve_context(session_id)
|
||||
|
||||
conf = self.model_configs
|
||||
conf = asdict(self.llm_config.model_config)
|
||||
if extra_conf: conf.update(extra_conf)
|
||||
|
||||
# start request
|
||||
@@ -358,10 +355,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if ok: continue
|
||||
else: raise Exception("所有 OpenAI API Key 目前都不可用。")
|
||||
except BadRequestError as e:
|
||||
retry += 1
|
||||
logger.warn(f"OpenAI 请求异常:{e}。")
|
||||
if "image_url is only supported by certain models." in str(e):
|
||||
raise Exception(f"当前模型 { self.model_configs['model'] } 不支持图片输入,请更换模型。")
|
||||
raise e
|
||||
raise Exception(f"当前模型 { self.get_curr_model() } 不支持图片输入,请更换模型。")
|
||||
except RateLimitError as e:
|
||||
if "You exceeded your current quota" in str(e):
|
||||
self.keys_data[self.chosen_api_key] = False
|
||||
@@ -437,11 +434,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
retry = 0
|
||||
conf = self.image_generator_model_configs
|
||||
super().accu_model_stat(model=conf['model'])
|
||||
if not conf:
|
||||
logger.error("OpenAI 图片生成模型配置不存在。")
|
||||
raise Exception("OpenAI 图片生成模型配置不存在。")
|
||||
|
||||
super().accu_model_stat(model=conf['model'])
|
||||
while retry < 3:
|
||||
try:
|
||||
images_response = await self.client.images.generate(
|
||||
@@ -497,12 +493,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
return contexts_str, len(self.session_memory[session_id])
|
||||
|
||||
def set_model(self, model: str):
|
||||
self.model_configs['model'] = model
|
||||
self.context.config_helper.put_by_dot_str("openai.chatGPTConfigs.model", model)
|
||||
# TODO: 更新配置文件
|
||||
super().set_curr_model(model)
|
||||
|
||||
def get_configs(self):
|
||||
return self.model_configs
|
||||
return asdict(self.llm_config)
|
||||
|
||||
def get_keys_data(self):
|
||||
return self.keys_data
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from aiocqhttp import Event
|
||||
|
||||
class MockOneBotMessage():
|
||||
@@ -10,4 +11,8 @@ class MockOneBotMessage():
|
||||
return self.group_event_sample
|
||||
|
||||
def create_random_direct_message(self):
|
||||
return self.friend_event_sample
|
||||
return self.friend_event_sample
|
||||
|
||||
def create_msg(self, text: str):
|
||||
self.group_event_sample.message = [{'data': {'qq': '3430871669'}, 'type': 'at'}, {'data': {'text': text}, 'type': 'text'}]
|
||||
return self.group_event_sample
|
||||
@@ -3,19 +3,19 @@ import botpy.message
|
||||
class MockQQOfficialMessage():
|
||||
def __init__(self):
|
||||
# 这些数据已经经过去敏处理
|
||||
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-s7tOAbAq.IwuxikQF99Zo0ZBTGwimNMI9tHdSVqDwLokBtxf6ZR0.wT2ZicHpFjKstG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T19:58:52+08:00'}
|
||||
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-gPHZcYCXwRupoe8vE-ZOTrTxu7SAaxnZZpw5EcmZ2njqYIyLrdKiL0AQzPPUtGntMtG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:06:32+08:00'}
|
||||
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_sS6HqVPgtqV99eGliL-B-sxsf5-CTemxnIrv6O3G6ZYZ6EVI3I2Z4wNye7dUiKuyvRiHM9aM.-tTLCT.qsJy1stG81ovPjw88HwjHppK6Gc!', 'timestamp': '2024-07-27T20:15:24+08:00'}
|
||||
self.group_plain_text_sample = {'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': 'just reply me `ok`', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T19:58:52+08:00'}
|
||||
self.group_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_Cii9ibiql8eHA1CAvaMB&rkey=CAESKE4_cASDm1t162vI7q9gitU2u0SUciVRg1fbyn3zYe9f_XHL2vhiB0s&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': ' ', 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:06:32+08:00'}
|
||||
self.group_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'size': 1440173, 'url': 'https://multimedia.nt.qq.com.cn/download?appid=1407&fileid=Cgk5MDU2MTc5OTISFBvbdDR6nYEHsqWEfYauN9wphLxlGK3zVyD_CiiMytyomceHA1CAvaMB&rkey=CAQSKDOc_jvbthUjVk7zSzPCqflD2XWA0OWzO5qCNsiRFY4RfQMuHYt8KDU&spec=0', 'width': 1186}], 'author': {'id': '3E47ABD92415AFEF02DAD74FFAB592D1', 'member_openid': '3E47ABD92415AFEF02DAD74FFAB592D1'}, 'content': " What's this", 'group_id': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'group_openid': 'BF5D5CA67932FFC4AFD18D4309DB759D', 'id': 'ROBOT1.0_test', 'timestamp': '2024-07-27T20:15:24+08:00'}
|
||||
self.group_event_id_sample = "GROUP_AT_MESSAGE_CREATE:ss6hqvpgtqv99eglilbjpsdzvudsjev64th8srgofxqkgxwpynhysl6q6ws849"
|
||||
|
||||
self.guild_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> just reply me `ok`', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438ef0e48a6c793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=OUbv2LTECcjQt48ibDS4OcA&kti=ZqTjpgAAAAI&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1903, 'seq_in_channel': '1903', 'timestamp': '2024-07-27T20:10:14+08:00'}
|
||||
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f10e48dbc793b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
|
||||
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': '08ffca96ebdaa68fcd6e108de3de0438f30e48a2c993b506', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
|
||||
self.guild_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2665728996', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2665728996-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': '<@!2519660939131724751> ', 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1905, 'seq_in_channel': '1905', 'timestamp': '2024-07-27T20:11:07+08:00'}
|
||||
self.guild_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2501183002', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/75802001660367636/9941389-2501183002-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'bot': False, 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '9941389', 'content': "<@!2519660939131724751> What's this", 'guild_id': '7969749791337194879', 'id': 'testid', 'member': {'joined_at': '2022-08-13T13:13:56+08:00', 'nick': 'Soulter', 'roles': ['4', '23']}, 'mentions': [{'avatar': 'http://thirdqq.qlogo.cn/g?b=oidb&k=mZ2Hn0BN5MLlBJTve0WIoA&kti=ZqTjnwAAAAA&s=0&t=1708501824', 'bot': True, 'id': '2519660939131724751', 'username': '浅橙Bot'}], 'seq': 1907, 'seq_in_channel': '1907', 'timestamp': '2024-07-27T20:14:26+08:00'}
|
||||
self.guild_event_id_sample = "AT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
self.direct_plain_text_sample = {'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': 'just reply me `ok`', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a5014898c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 165, 'seq_in_channel': '165', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:08+08:00'}
|
||||
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a70148adc893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
|
||||
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': '08caaea38bcaabbe942f10afaf8fb08fa49d3b38a80148f2c893b506', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
|
||||
self.direct_plain_image_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2658044992', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2658044992-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 167, 'seq_in_channel': '167', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:12:29+08:00'}
|
||||
self.direct_multimedia_sample = {'attachments': [{'content_type': 'image/png', 'filename': '165FCBF8BD6F42496B58A6C66C5D4255.png', 'height': 1034, 'id': '2526212938', 'size': 1440173, 'url': 'gchat.qpic.cn/qmeetpic/92265551678707631/33342831678707631-2526212938-165FCBF8BD6F42496B58A6C66C5D4255/0', 'width': 1186}], 'author': {'avatar': 'https://qqchannel-profile-1251316161.file.myqcloud.com/168087977775f0eae70da8e512?t=1680879777', 'id': '6946931796791550499', 'username': 'Soulter'}, 'channel_id': '33342831678707631', 'content': "What's this", 'direct_message': True, 'guild_id': '3398240095091349322', 'id': 'testid', 'member': {'joined_at': '2023-03-13T19:40:31+08:00'}, 'seq': 168, 'seq_in_channel': '168', 'src_guild_id': '7969749791337194879', 'timestamp': '2024-07-27T20:13:38+08:00'}
|
||||
self.direct_event_id_sample = "DIRECT_MESSAGE_CREATE:e4c09708-781d-44d0-b8cf-34bf3d4e2e64"
|
||||
|
||||
def create_random_group_message(self):
|
||||
@@ -42,4 +42,13 @@ class MockQQOfficialMessage():
|
||||
)
|
||||
return mocked
|
||||
|
||||
def create_msg(self, text: str):
|
||||
sample = self.group_plain_text_sample.copy()
|
||||
sample['content'] = text
|
||||
mocked = botpy.message.Message(
|
||||
api=None,
|
||||
event_id=self.group_event_id_sample,
|
||||
data=sample
|
||||
)
|
||||
return mocked
|
||||
|
||||
|
||||
+89
-4
@@ -8,11 +8,14 @@ from tests.mocks.onebot import MockOneBotMessage
|
||||
from astrbot.bootstrap import AstrBotBootstrap
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
from model.provider.openai_official import ProviderOpenAIOfficial
|
||||
from type.astrbot_message import *
|
||||
from type.message_event import *
|
||||
from SparkleLogging.utils.core import LogManager
|
||||
from logging import Formatter
|
||||
|
||||
from util.cmd_config import QQOfficialPlatformConfig, AiocqhttpPlatformConfig
|
||||
|
||||
logger = LogManager.GetLogger(
|
||||
log_name='astrbot',
|
||||
out_to_console=True,
|
||||
@@ -22,10 +25,23 @@ pytest_plugins = ('pytest_asyncio',)
|
||||
|
||||
os.environ['TEST_MODE'] = 'on'
|
||||
bootstrap = AstrBotBootstrap()
|
||||
asyncio.run(bootstrap.run())
|
||||
|
||||
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler)
|
||||
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler)
|
||||
llm_config = bootstrap.context.config_helper.llm[0]
|
||||
llm_config.api_base = os.environ['OPENAI_API_BASE']
|
||||
llm_config.key = [os.environ['OPENAI_API_KEY']]
|
||||
llm_config.model_config.model = os.environ['LLM_MODEL']
|
||||
llm_config.model_config.max_tokens = 1000
|
||||
llm_provider = ProviderOpenAIOfficial(llm_config)
|
||||
asyncio.run(bootstrap.run())
|
||||
bootstrap.message_handler.provider = llm_provider
|
||||
bootstrap.config_helper.wake_prefix = ["/"]
|
||||
bootstrap.config_helper.admins_id = ["905617992"]
|
||||
|
||||
for p_config in bootstrap.context.config_helper.platform:
|
||||
if isinstance(p_config, QQOfficialPlatformConfig):
|
||||
qq_official = QQOfficial(bootstrap.context, bootstrap.message_handler, p_config)
|
||||
elif isinstance(p_config, AiocqhttpPlatformConfig):
|
||||
aiocqhttp = AIOCQHTTP(bootstrap.context, bootstrap.message_handler, p_config)
|
||||
|
||||
class TestBasicMessageHandle():
|
||||
@pytest.mark.asyncio
|
||||
@@ -62,4 +78,73 @@ class TestBasicMessageHandle():
|
||||
event = MockOneBotMessage().create_random_direct_message()
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
print(ret)
|
||||
|
||||
class TestInteralCommandHsandle():
|
||||
def create(self, text: str):
|
||||
event = MockOneBotMessage().create_msg(text)
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
return abm
|
||||
|
||||
async def fast_test(self, text: str):
|
||||
abm = self.create(text)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(f"Command: {text}, Result: {ret.result_message}")
|
||||
return ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_save(self):
|
||||
abm = self.create("/websearch on")
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
assert bootstrap.context.config_helper.llm_settings.web_search \
|
||||
== bootstrap.config_helper.get("llm_settings")['web_search']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websearch(self):
|
||||
await self.fast_test("/websearch")
|
||||
await self.fast_test("/websearch on")
|
||||
await self.fast_test("/websearch off")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_help(self):
|
||||
await self.fast_test("/help")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_myid(self):
|
||||
await self.fast_test("/myid")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wake(self):
|
||||
await self.fast_test("/wake")
|
||||
await self.fast_test("/wake #")
|
||||
assert "#" in bootstrap.context.config_helper.wake_prefix
|
||||
assert "#" in bootstrap.context.config_helper.get("wake_prefix")
|
||||
await self.fast_test("#wake /")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sleep(self):
|
||||
await self.fast_test("/provider")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update(self):
|
||||
await self.fast_test("/update")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_t2i(self):
|
||||
if not bootstrap.context.config_helper.t2i:
|
||||
abm = self.create("/t2i")
|
||||
await aiocqhttp.handle_msg(abm)
|
||||
await self.fast_test("/help")
|
||||
|
||||
class TestLLMChat():
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_chat(self):
|
||||
os.environ["TEST_LLM"] = "on"
|
||||
ret = await llm_provider.text_chat("Just reply `ok`", "test")
|
||||
print(ret)
|
||||
event = MockOneBotMessage().create_msg("Just reply `ok`")
|
||||
abm = aiocqhttp.convert_message(event)
|
||||
ret = await aiocqhttp.handle_msg(abm)
|
||||
print(ret)
|
||||
os.environ["TEST_LLM"] = "off"
|
||||
|
||||
+281
@@ -73,3 +73,284 @@ DEFAULT_CONFIG = {
|
||||
"ws_reverse_port": 0,
|
||||
}
|
||||
}
|
||||
|
||||
# 新版本配置文件,摈弃旧版本令人困惑的配置项 :D
|
||||
DEFAULT_CONFIG_VERSION_2 = {
|
||||
"config_version": 2,
|
||||
"platform": [
|
||||
{
|
||||
"id": "default",
|
||||
"name": "qq_official",
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"enable_group_c2c": True,
|
||||
"enable_guild_direct_message": True,
|
||||
},
|
||||
{
|
||||
"id": "default",
|
||||
"name": "nakuru",
|
||||
"enable": False,
|
||||
"host": "172.0.0.1",
|
||||
"port": 5700,
|
||||
"websocket_port": 6700,
|
||||
"enable_group": True,
|
||||
"enable_guild": True,
|
||||
"enable_direct_message": True,
|
||||
"enable_group_increase": True,
|
||||
},
|
||||
{
|
||||
"id": "default",
|
||||
"name": "aiocqhttp",
|
||||
"enable": False,
|
||||
"ws_reverse_host": "",
|
||||
"ws_reverse_port": 6199,
|
||||
}
|
||||
],
|
||||
"platform_settings": {
|
||||
"unique_session": False,
|
||||
"welcome_message_when_join": "",
|
||||
"rate_limit": {
|
||||
"time": 60,
|
||||
"count": 30,
|
||||
},
|
||||
"reply_prefix": "",
|
||||
"forward_threshold": 200, # 转发消息的阈值
|
||||
},
|
||||
"llm": [
|
||||
{
|
||||
"id": "default",
|
||||
"name": "openai",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"prompt_prefix": "",
|
||||
"default_personality": "",
|
||||
"model_config": {
|
||||
"model": "gpt-4o",
|
||||
"max_tokens": 6000,
|
||||
"temperature": 0.9,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
},
|
||||
"image_generation_model_config": {
|
||||
"enable": True,
|
||||
"model": "dall-e-3",
|
||||
"size": "1024x1024",
|
||||
"style": "vivid",
|
||||
"quality": "standard",
|
||||
}
|
||||
},
|
||||
],
|
||||
"llm_settings": {
|
||||
"wake_prefix": "",
|
||||
"web_search": False,
|
||||
},
|
||||
"content_safety": {
|
||||
"baidu_aip": {
|
||||
"enable": False,
|
||||
"app_id": "",
|
||||
"api_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
"internal_keywords": {
|
||||
"enable": True,
|
||||
"extra_keywords": [],
|
||||
}
|
||||
},
|
||||
"wake_prefix": [],
|
||||
"t2i": True,
|
||||
"dump_history_interval": 10,
|
||||
"admins_id": [],
|
||||
"https_proxy": "",
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "",
|
||||
"password": "",
|
||||
},
|
||||
}
|
||||
|
||||
# 这个是用于迁移旧版本配置文件的映射表
|
||||
MAPPINGS_1_2 = [
|
||||
[["qqbot", "enable"], ["platform", 0, "enable"]],
|
||||
[["qqbot", "appid"], ["platform", 0, "appid"]],
|
||||
[["qqbot", "token"], ["platform", 0, "secret"]],
|
||||
[["qqofficial_enable_group_message"], ["platform", 0, "enable_group_c2c"]],
|
||||
[["direct_message_mode"], ["platform", 0, "enable_guild_direct_message"]],
|
||||
[["gocqbot", "enable"], ["platform", 1, "enable"]],
|
||||
[["gocq_host"], ["platform", 1, "host"]],
|
||||
[["gocq_http_port"], ["platform", 1, "port"]],
|
||||
[["gocq_websocket_port"], ["platform", 1, "websocket_port"]],
|
||||
[["gocq_react_group"], ["platform", 1, "enable_group"]],
|
||||
[["gocq_react_guild"], ["platform", 1, "enable_guild"]],
|
||||
[["gocq_react_friend"], ["platform", 1, "enable_direct_message"]],
|
||||
[["gocq_react_group_increase"], ["platform", 1, "enable_group_increase"]],
|
||||
[["aiocqhttp", "enable"], ["platform", 2, "enable"]],
|
||||
[["aiocqhttp", "ws_reverse_host"], ["platform", 2, "ws_reverse_host"]],
|
||||
[["aiocqhttp", "ws_reverse_port"], ["platform", 2, "ws_reverse_port"]],
|
||||
[["uniqueSessionMode"], ["platform_settings", "unique_session"]],
|
||||
[["qq_welcome"], ["platform_settings", "welcome_message_when_join"]],
|
||||
[["limit", "time"], ["platform_settings", "rate_limit", "time"]],
|
||||
[["limit", "count"], ["platform_settings", "rate_limit", "count"]],
|
||||
[["reply_prefix"], ["platform_settings", "reply_prefix"]],
|
||||
[["qq_forward_threshold"], ["platform_settings", "forward_threshold"]],
|
||||
|
||||
[["openai", "key"], ["llm", 0, "key"]],
|
||||
[["openai", "api_base"], ["llm", 0, "api_base"]],
|
||||
[["openai", "chatGPTConfigs", "model"], ["llm", 0, "model_config", "model"]],
|
||||
[["openai", "chatGPTConfigs", "max_tokens"], ["llm", 0, "model_config", "max_tokens"]],
|
||||
[["openai", "chatGPTConfigs", "temperature"], ["llm", 0, "model_config", "temperature"]],
|
||||
[["openai", "chatGPTConfigs", "top_p"], ["llm", 0, "model_config", "top_p"]],
|
||||
[["openai", "chatGPTConfigs", "frequency_penalty"], ["llm", 0, "model_config", "frequency_penalty"]],
|
||||
[["openai", "chatGPTConfigs", "presence_penalty"], ["llm", 0, "model_config", "presence_penalty"]],
|
||||
|
||||
[["default_personality_str"], ["llm", 0, "default_personality"]],
|
||||
[["llm_env_prompt"], ["llm", 0, "prompt_prefix"]],
|
||||
[["openai_image_generate", "model"], ["llm", 0, "image_generation_model_config", "model"]],
|
||||
[["openai_image_generate", "size"], ["llm", 0, "image_generation_model_config", "size"]],
|
||||
[["openai_image_generate", "style"], ["llm", 0, "image_generation_model_config", "style"]],
|
||||
[["openai_image_generate", "quality"], ["llm", 0, "image_generation_model_config", "quality"]],
|
||||
|
||||
[["llm_wake_prefix"], ["llm_settings", "wake_prefix"]],
|
||||
|
||||
[["baidu_aip", "enable"], ["content_safety", "baidu_aip", "enable"]],
|
||||
[["baidu_aip", "app_id"], ["content_safety", "baidu_aip", "app_id"]],
|
||||
[["baidu_aip", "api_key"], ["content_safety", "baidu_aip", "api_key"]],
|
||||
[["baidu_aip", "secret_key"], ["content_safety", "baidu_aip", "secret_key"]],
|
||||
|
||||
[["qq_pic_mode"], ["t2i"]],
|
||||
[["dump_history_interval"], ["dump_history_interval"]],
|
||||
[["other_admins"], ["admins_id"]],
|
||||
[["http_proxy"], ["http_proxy"]],
|
||||
[["https_proxy"], ["https_proxy"]],
|
||||
[["dashboard_username"], ["dashboard", "username"]],
|
||||
[["dashboard_password"], ["dashboard", "password"]],
|
||||
[["nick_qq"], ["wake_prefix"]],
|
||||
]
|
||||
|
||||
CONFIG_METADATA_2 = {
|
||||
"config_version": {"description": "配置版本", "type": "int"},
|
||||
"platform": {
|
||||
"description": "平台配置",
|
||||
"type": "list",
|
||||
"items": {
|
||||
"name": {"description": "平台名称", "type": "string"},
|
||||
"enable": {"description": "是否启用", "type": "bool"},
|
||||
"appid": {"description": "应用ID", "type": "string"},
|
||||
"secret": {"description": "应用密钥", "type": "string"},
|
||||
"enable_group_c2c": {"description": "启用群C2C", "type": "bool"},
|
||||
"enable_guild_direct_message": {"description": "启用公会直接消息", "type": "bool"},
|
||||
"host": {"description": "主机地址", "type": "string"},
|
||||
"port": {"description": "端口", "type": "int"},
|
||||
"websocket_port": {"description": "WebSocket端口", "type": "int"},
|
||||
"ws_reverse_host": {"description": "WebSocket反向主机", "type": "string"},
|
||||
"ws_reverse_port": {"description": "WebSocket反向端口", "type": "int"},
|
||||
"enable_group": {"description": "启用群组", "type": "bool"},
|
||||
"enable_guild": {"description": "启用公会", "type": "bool"},
|
||||
"enable_direct_message": {"description": "启用直接消息", "type": "bool"},
|
||||
"enable_group_increase": {"description": "启用群组增加", "type": "bool"},
|
||||
}
|
||||
},
|
||||
"platform_settings": {
|
||||
"description": "平台设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"unique_session": {"description": "唯一会话", "type": "bool"},
|
||||
"welcome_message_when_join": {"description": "加入时欢迎信息", "type": "string"},
|
||||
"rate_limit": {
|
||||
"description": "速率限制",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"time": {"description": "时间", "type": "int"},
|
||||
"count": {"description": "计数", "type": "int"},
|
||||
}
|
||||
},
|
||||
"reply_prefix": {"description": "回复前缀", "type": "string"},
|
||||
"forward_threshold": {"description": "转发消息的阈值", "type": "int"},
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
"description": "大语言模型配置",
|
||||
"type": "list",
|
||||
"items": {
|
||||
"name": {"description": "模型名称", "type": "string"},
|
||||
"enable": {"description": "是否启用", "type": "bool"},
|
||||
"key": {"description": "密钥", "type": "list", "items": {"type": "string"}},
|
||||
"api_base": {"description": "API基础URL", "type": "string"},
|
||||
"prompt_prefix": {"description": "提示前缀", "type": "string"},
|
||||
"default_personality": {"description": "默认个性", "type": "string"},
|
||||
"model_config": {
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {"description": "模型名称", "type": "string"},
|
||||
"max_tokens": {"description": "最大令牌数", "type": "int"},
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
"frequency_penalty": {"description": "频率惩罚", "type": "float"},
|
||||
"presence_penalty": {"description": "存在惩罚", "type": "float"},
|
||||
}
|
||||
},
|
||||
"image_generation_model_config": {
|
||||
"description": "图像生成模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "是否启用", "type": "bool"},
|
||||
"model": {"description": "模型名称", "type": "string"},
|
||||
"size": {"description": "图像尺寸", "type": "string"},
|
||||
"style": {"description": "图像风格", "type": "string"},
|
||||
"quality": {"description": "图像质量", "type": "string"},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"llm_settings": {
|
||||
"description": "大语言模型设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"wake_prefix": {"description": "唤醒前缀", "type": "string"},
|
||||
"web_search": {"description": "启用网络搜索", "type": "bool"},
|
||||
}
|
||||
},
|
||||
"content_safety": {
|
||||
"description": "内容安全配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"baidu_aip": {
|
||||
"description": "百度AI平台配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "是否启用", "type": "bool"},
|
||||
"app_id": {"description": "应用ID", "type": "string"},
|
||||
"api_key": {"description": "API密钥", "type": "string"},
|
||||
"secret_key": {"description": "秘密密钥", "type": "string"},
|
||||
}
|
||||
},
|
||||
"internal_keywords": {
|
||||
"description": "内部关键词过滤",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "是否启用", "type": "bool"},
|
||||
"extra_keywords": {"description": "额外关键词", "type": "list", "items": {"type": "string"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"wake_prefix": {"description": "唤醒前缀列表", "type": "list", "items": {"type": "string"}},
|
||||
"t2i": {"description": "文本转图像功能", "type": "bool"},
|
||||
"dump_history_interval": {"description": "历史记录转储间隔", "type": "int"},
|
||||
"admins_id": {"description": "管理员ID列表", "type": "list", "items": {"type": "int"}},
|
||||
"https_proxy": {"description": "HTTPS代理", "type": "string"},
|
||||
"http_proxy": {"description": "HTTP代理", "type": "string"},
|
||||
"dashboard": {
|
||||
"description": "仪表盘配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "是否启用", "type": "bool"},
|
||||
"username": {"description": "用户名", "type": "string"},
|
||||
"password": {"description": "密码", "type": "string"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
+9
-9
@@ -3,7 +3,7 @@ from asyncio import Task
|
||||
from type.register import *
|
||||
from typing import List, Awaitable
|
||||
from logging import Logger
|
||||
from util.cmd_config import CmdConfig
|
||||
from util.cmd_config import AstrBotConfig
|
||||
from util.t2i.renderer import TextToImageRenderer
|
||||
from util.updator.astrbot_updator import AstrBotUpdator
|
||||
from util.image_uploader import ImageUploader
|
||||
@@ -21,19 +21,19 @@ class Context:
|
||||
'''
|
||||
|
||||
def __init__(self):
|
||||
self.running = True
|
||||
self.logger: Logger = None
|
||||
self.base_config: dict = None # 配置(期望启动机器人后是不变的)
|
||||
self.config_helper: CmdConfig = None
|
||||
self.config_helper: AstrBotConfig = None
|
||||
self.cached_plugins: List[RegisteredPlugin] = [] # 缓存的插件
|
||||
self.platforms: List[RegisteredPlatform] = []
|
||||
self.llms: List[RegisteredLLM] = []
|
||||
self.default_personality: dict = None
|
||||
|
||||
self.unique_session = False # 独立会话
|
||||
self.version: str = None # 机器人版本
|
||||
self.nick: tuple = None # gocq 的唤醒词
|
||||
self.t2i_mode = False
|
||||
self.web_search = False # 是否开启了网页搜索
|
||||
# self.unique_session = False # 独立会话
|
||||
# self.version: str = None # 机器人版本
|
||||
# self.nick: tuple = None # gocq 的唤醒词
|
||||
# self.t2i_mode = False
|
||||
# self.web_search = False # 是否开启了网页搜索
|
||||
|
||||
self.metrics_uploader = None
|
||||
self.updator: AstrBotUpdator = None
|
||||
@@ -48,7 +48,7 @@ class Context:
|
||||
self.running = True
|
||||
|
||||
# useless
|
||||
self.reply_prefix = ""
|
||||
# self.reply_prefix = ""
|
||||
|
||||
def register_commands(self,
|
||||
plugin_name: str,
|
||||
|
||||
+250
-46
@@ -1,33 +1,242 @@
|
||||
import os
|
||||
import json
|
||||
from type.config import DEFAULT_CONFIG
|
||||
import logging
|
||||
from type.config import DEFAULT_CONFIG, DEFAULT_CONFIG_VERSION_2, MAPPINGS_1_2
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
cpath = "data/cmd_config.json"
|
||||
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
def check_exist():
|
||||
if not os.path.exists(cpath):
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump({}, f, ensure_ascii=False)
|
||||
f.flush()
|
||||
@dataclass
|
||||
class RateLimit:
|
||||
time: int = 60
|
||||
count: int = 30
|
||||
|
||||
@dataclass
|
||||
class PlatformSettings:
|
||||
unique_session: bool = False
|
||||
welcome_message_when_join: str = ""
|
||||
rate_limit: RateLimit = field(default_factory=RateLimit)
|
||||
reply_prefix: str = ""
|
||||
forward_threshold: int = 200
|
||||
|
||||
def __post_init__(self):
|
||||
self.rate_limit = RateLimit(**self.rate_limit)
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig():
|
||||
id: str = ""
|
||||
name: str = ""
|
||||
enable: bool = False
|
||||
|
||||
@dataclass
|
||||
class QQOfficialPlatformConfig(PlatformConfig):
|
||||
appid: str = ""
|
||||
secret: str = ""
|
||||
enable_group_c2c: bool = True
|
||||
enable_guild_direct_message: bool = True
|
||||
|
||||
@dataclass
|
||||
class NakuruPlatformConfig(PlatformConfig):
|
||||
host: str = "172.0.0.1",
|
||||
port: int = 5700,
|
||||
websocket_port: int = 6700,
|
||||
enable_group: bool = True,
|
||||
enable_guild: bool = True,
|
||||
enable_direct_message: bool = True,
|
||||
enable_group_increase: bool = True
|
||||
|
||||
@dataclass
|
||||
class AiocqhttpPlatformConfig(PlatformConfig):
|
||||
ws_reverse_host: str = ""
|
||||
ws_reverse_port: int = 6199
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
model: str = "gpt-4o"
|
||||
max_tokens: int = 6000
|
||||
temperature: float = 0.9
|
||||
top_p: float = 1
|
||||
frequency_penalty: float = 0
|
||||
presence_penalty: float = 0
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationModelConfig:
|
||||
enable: bool = True
|
||||
model: str = "dall-e-3"
|
||||
size: str = "1024x1024"
|
||||
style: str = "vivid"
|
||||
quality: str = "standard"
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
id: str = ""
|
||||
name: str = "openai"
|
||||
enable: bool = True
|
||||
key: List[str] = field(default_factory=list)
|
||||
api_base: str = ""
|
||||
prompt_prefix: str = ""
|
||||
default_personality: str = ""
|
||||
model_config: ModelConfig = field(default_factory=ModelConfig)
|
||||
image_generation_model_config: Optional[ImageGenerationModelConfig] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.model_config = ModelConfig(**self.model_config)
|
||||
if self.image_generation_model_config:
|
||||
self.image_generation_model_config = ImageGenerationModelConfig(**self.image_generation_model_config)
|
||||
@dataclass
|
||||
class LLMSettings:
|
||||
wake_prefix: str = ""
|
||||
web_search: bool = False
|
||||
|
||||
@dataclass
|
||||
class BaiduAIPConfig:
|
||||
enable: bool = False
|
||||
app_id: str = ""
|
||||
api_key: str = ""
|
||||
secret_key: str = ""
|
||||
|
||||
@dataclass
|
||||
class InternalKeywordsConfig:
|
||||
enable: bool = True
|
||||
extra_keywords: List[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class ContentSafetyConfig:
|
||||
baidu_aip: BaiduAIPConfig = field(default_factory=BaiduAIPConfig)
|
||||
internal_keywords: InternalKeywordsConfig = field(default_factory=InternalKeywordsConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
self.baidu_aip = BaiduAIPConfig(**self.baidu_aip)
|
||||
self.internal_keywords = InternalKeywordsConfig(**self.internal_keywords)
|
||||
|
||||
@dataclass
|
||||
class DashboardConfig:
|
||||
enable: bool = True
|
||||
username: str = ""
|
||||
password: str = ""
|
||||
|
||||
@dataclass
|
||||
class AstrBotConfig():
|
||||
config_version: int = 2
|
||||
platform_settings: PlatformSettings = field(default_factory=PlatformSettings)
|
||||
llm: List[LLMConfig] = field(default_factory=list)
|
||||
llm_settings: LLMSettings = field(default_factory=LLMSettings)
|
||||
content_safety: ContentSafetyConfig = field(default_factory=ContentSafetyConfig)
|
||||
t2i: bool = True
|
||||
dump_history_interval: int = 10
|
||||
admins_id: List[str] = field(default_factory=list)
|
||||
https_proxy: str = ""
|
||||
http_proxy: str = ""
|
||||
dashboard: DashboardConfig = field(default_factory=DashboardConfig)
|
||||
platform: List[PlatformConfig] = field(default_factory=list)
|
||||
wake_prefix: List[str] = field(default_factory=list)
|
||||
|
||||
class CmdConfig():
|
||||
def __init__(self) -> None:
|
||||
self.cached_config: dict = {}
|
||||
self.init_configs()
|
||||
|
||||
# compability
|
||||
if isinstance(self.wake_prefix, str):
|
||||
self.wake_prefix = [self.wake_prefix]
|
||||
|
||||
def load_from_dict(self, data: Dict):
|
||||
'''从字典中加载配置到对象。
|
||||
|
||||
@note: 适用于 version 2 配置文件。
|
||||
'''
|
||||
self.config_version=data.get("version", 2)
|
||||
self.platform=[]
|
||||
for p in data.get("platform", []):
|
||||
if 'name' not in p:
|
||||
logger.warning("A platform config missing name, skipping.")
|
||||
continue
|
||||
if p["name"] == "qq_official":
|
||||
self.platform.append(QQOfficialPlatformConfig(**p))
|
||||
elif p["name"] == "nakuru":
|
||||
self.platform.append(NakuruPlatformConfig(**p))
|
||||
elif p["name"] == "aiocqhttp":
|
||||
self.platform.append(AiocqhttpPlatformConfig(**p))
|
||||
else:
|
||||
self.platform.append(PlatformConfig(**p))
|
||||
self.platform_settings=PlatformSettings(**data.get("platform_settings", {}))
|
||||
self.llm=[LLMConfig(**l) for l in data.get("llm", [])]
|
||||
self.llm_settings=LLMSettings(**data.get("llm_settings", {}))
|
||||
self.content_safety=ContentSafetyConfig(**data.get("content_safety", {}))
|
||||
self.t2i=data.get("t2i", True)
|
||||
self.dump_history_interval=data.get("dump_history_interval", 10)
|
||||
self.admins_id=data.get("admins_id", [])
|
||||
self.https_proxy=data.get("https_proxy", "")
|
||||
self.http_proxy=data.get("http_proxy", "")
|
||||
self.dashboard=DashboardConfig(**data.get("dashboard", {}))
|
||||
self.wake_prefix=data.get("wake_prefix", [])
|
||||
|
||||
def migrate_config_1_2(self, old: dict) -> dict:
|
||||
'''将配置文件从版本 1 迁移至版本 2'''
|
||||
logger.info("正在更新配置文件到 version 2...")
|
||||
new_config = DEFAULT_CONFIG_VERSION_2
|
||||
mappings = MAPPINGS_1_2
|
||||
|
||||
def set_nested_value(d, keys, value):
|
||||
cursor = d
|
||||
for key in keys[:-1]:
|
||||
cursor = cursor[key]
|
||||
cursor[keys[-1]] = value
|
||||
|
||||
for old_path, new_path in mappings:
|
||||
value = old
|
||||
try:
|
||||
for key in old_path:
|
||||
value = value[key] # soooooo convenient!!
|
||||
set_nested_value(new_config, new_path, value)
|
||||
except KeyError:
|
||||
# 如果旧配置中没有这个键,跳过,即使用新配置的默认值
|
||||
continue
|
||||
|
||||
logger.info("配置文件更新完成。")
|
||||
return new_config
|
||||
|
||||
def flush_config(self, config: dict = None):
|
||||
'''将配置写入文件, 如果没有传入配置,则写入默认配置'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(config if config else DEFAULT_CONFIG_VERSION_2, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def save_config(self):
|
||||
'''将现存配置写入文件'''
|
||||
self.flush_config(self.to_dict())
|
||||
|
||||
def init_configs(self):
|
||||
'''
|
||||
初始化必需的配置项
|
||||
'''
|
||||
self.init_config_items(DEFAULT_CONFIG)
|
||||
'''初始化必需的配置项'''
|
||||
config = None
|
||||
|
||||
if not self.check_exist():
|
||||
self.flush_config()
|
||||
config = DEFAULT_CONFIG_VERSION_2
|
||||
else:
|
||||
config = self.get_all()
|
||||
# check if the config is outdated
|
||||
if 'config_version' not in config: # version 1
|
||||
config = self.migrate_config_1_2(config)
|
||||
self.flush_config(config)
|
||||
|
||||
_tag = False
|
||||
for key, val in DEFAULT_CONFIG_VERSION_2.items():
|
||||
if key not in config:
|
||||
config[key] = val
|
||||
_tag = True
|
||||
if _tag:
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(config, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
@staticmethod
|
||||
def get(key, default=None):
|
||||
self.load_from_dict(config)
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
'''
|
||||
从文件系统中直接获取配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
if key in d:
|
||||
return d[key]
|
||||
@@ -38,30 +247,29 @@ class CmdConfig():
|
||||
'''
|
||||
从文件系统中获取所有配置
|
||||
'''
|
||||
check_exist()
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
conf_str = f.read()
|
||||
if conf_str.startswith(u'/ufeff'): # remove BOM
|
||||
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
|
||||
if not conf_str:
|
||||
return {}
|
||||
conf = json.loads(conf_str)
|
||||
return conf
|
||||
|
||||
def put(self, key, value):
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
d[key] = value
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
self.cached_config[key] = value
|
||||
|
||||
@staticmethod
|
||||
def put_by_dot_str(key: str, value):
|
||||
'''
|
||||
根据点分割的字符串,将值写入配置文件
|
||||
'''
|
||||
with open(cpath, "r", encoding="utf-8-sig") as f:
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return asdict(self)
|
||||
|
||||
def put_by_dot_str(self, key: str, value):
|
||||
'''根据点分割的字符串,将值写入配置文件'''
|
||||
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
|
||||
d = json.load(f)
|
||||
_d = d
|
||||
_ks = key.split(".")
|
||||
@@ -70,23 +278,19 @@ class CmdConfig():
|
||||
_d[_ks[i]] = value
|
||||
else:
|
||||
_d = _d[_ks[i]]
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(d, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def init_config_items(self, d: dict):
|
||||
conf = self.get_all()
|
||||
|
||||
def update_by_path(self, path: List):
|
||||
'''根据路径更新配置文件。
|
||||
|
||||
if not self.cached_config:
|
||||
self.cached_config = conf
|
||||
这个方法首先会更新缓存在内存中的配置,然后再写入文件。
|
||||
'''
|
||||
|
||||
for key in path:
|
||||
if key not in self:
|
||||
raise KeyError(f"Key {key} not found in config.")
|
||||
|
||||
_tag = False
|
||||
|
||||
for key, val in d.items():
|
||||
if key not in conf:
|
||||
conf[key] = val
|
||||
_tag = True
|
||||
if _tag:
|
||||
with open(cpath, "w", encoding="utf-8-sig") as f:
|
||||
json.dump(conf, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
def check_exist(self) -> bool:
|
||||
return os.path.exists(ASTRBOT_CONFIG_PATH)
|
||||
+9
-10
@@ -1,16 +1,15 @@
|
||||
import json, os
|
||||
from util.cmd_config import CmdConfig
|
||||
import json, os, shutil
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
def try_migrate_config():
|
||||
'''
|
||||
将 cmd_config.json 迁移至 data/cmd_config.json (如果存在的话)
|
||||
'''
|
||||
if os.path.exists("cmd_config.json"):
|
||||
with open("cmd_config.json", "r", encoding="utf-8-sig") as f:
|
||||
data = json.load(f)
|
||||
with open("data/cmd_config.json", "w", encoding="utf-8-sig") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
if os.path.exists("cmd_config.json") and not os.path.exists("data/cmd_config.json"):
|
||||
try:
|
||||
os.remove("cmd_config.json")
|
||||
except Exception as e:
|
||||
pass
|
||||
shutil.move("cmd_config.json", "data/cmd_config.json")
|
||||
except:
|
||||
logger.error("迁移 cmd_config.json 失败。AstrBot 将不会读取配置文件,你可以手动将 cmd_config.json 迁移至 data/cmd_config.json。")
|
||||
|
||||
+2
-1
@@ -5,6 +5,7 @@ import sys
|
||||
|
||||
from type.types import Context
|
||||
from collections import defaultdict
|
||||
from type.config import VERSION
|
||||
|
||||
class MetricUploader():
|
||||
def __init__(self, context: Context) -> None:
|
||||
@@ -49,7 +50,7 @@ class MetricUploader():
|
||||
try:
|
||||
res = {
|
||||
"stat_version": "moon",
|
||||
"version": context.version, # 版本号
|
||||
"version": VERSION, # 版本号
|
||||
"platform_stats": self.platform_stats, # 过去 30 分钟各消息平台交互消息数
|
||||
"llm_stats": self.llm_stats,
|
||||
"plugin_stats": self.plugin_stats,
|
||||
|
||||
@@ -7,6 +7,6 @@ Platform类是消息平台的抽象类,定义了消息平台的基本接口。
|
||||
|
||||
from model.platform import Platform
|
||||
|
||||
from model.platform.qq_nakuru import QQGOCQ
|
||||
from model.platform.qq_nakuru import QQNakuru
|
||||
from model.platform.qq_official import QQOfficial
|
||||
from model.platform.qq_aiocqhttp import AIOCQHTTP
|
||||
Reference in New Issue
Block a user