feat: 使用 jwt 用于管理面板鉴权

This commit is contained in:
Soulter
2024-12-03 19:35:07 +08:00
parent 4a52779d09
commit 7abe90f2ac
17 changed files with 132 additions and 73 deletions
+2 -1
View File
@@ -5,4 +5,5 @@ from astrbot.core.config.default import DB_PATH
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
db_helper = SQLiteDatabase(DB_PATH)
db_helper = SQLiteDatabase(DB_PATH)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+1
View File
@@ -142,6 +142,7 @@ class ProjectATRI:
long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory)
active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage)
persona: str = ""
split_response: bool = True
embedding_provider_id: str = ""
summarize_provider_id: str = ""
chat_provider_id: str = ""
+18 -4
View File
@@ -3,7 +3,7 @@
'''
VERSION = '3.4.0'
DB_PATH = 'data/data_v2.db'
DB_PATH = 'data/data_v3.db'
# LLM 提供商配置模板
PROVIDER_CONFIG_TEMPLATE = {
@@ -144,8 +144,8 @@ DEFAULT_CONFIG_VERSION_2 = {
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "",
"password": "",
"username": "astrbot",
"password": "77b90590a8945a7d36c963981a307dc9",
},
"log_level": "INFO",
"t2i_endpoint": "",
@@ -160,7 +160,12 @@ DEFAULT_CONFIG_VERSION_2 = {
"active_message": {
"enable": False,
},
"vision": {
"enable": False,
"provider_id_or_ofa_model_path": "",
},
"persona": "",
"split_response": True,
"embedding_provider_id": "",
"summarize_provider_id": "",
"chat_provider_id": "",
@@ -188,7 +193,7 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"},
"qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"},
"qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"},
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"},
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "obvious_hint": True, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"},
}
},
"platform_settings": {
@@ -322,6 +327,15 @@ CONFIG_METADATA_2 = {
"enable": {"description": "启用", "type": "bool"},
}
},
"vision": {
"description": "视觉理解",
"type": "object",
"items": {
"enable": {"description": "启用", "type": "bool"},
"provider_id_or_ofa_model_path": {"description": "提供商 ID 或 OFA 模型路径", "type": "string", "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。"},
}
},
"split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的事件间隔。默认启用。"},
"persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True},
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
+1
View File
@@ -417,6 +417,7 @@ class Unknown(BaseMessageComponent):
ComponentTypes = {
"plain": Plain,
"text": Plain,
"face": Face,
"record": Record,
"video": Video,
+1
View File
@@ -2,6 +2,7 @@ import os
import ssl
import shutil
import socket
import ifaddr
import time
import aiohttp
import base64
+33 -12
View File
@@ -1,13 +1,14 @@
from .route import Route, Response
import jwt, datetime
from .route import Route, Response, RouteContext
from quart import Quart, request
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import WEBUI_SK
class AuthRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart) -> None:
super().__init__(config, app)
def __init__(self, context: RouteContext) -> None:
super().__init__(context)
self.routes = {
'/auth/login': ('POST', self.login),
'/auth/password/reset': ('POST', self.reset_password),
'/auth/account/edit': ('POST', self.edit_account),
}
self.register_routes()
@@ -17,17 +18,37 @@ class AuthRoute(Route):
post_data = await request.json
if post_data["username"] == username and post_data["password"] == password:
return Response().ok({
"token": "astrbot-test-token",
"token": self.generate_jwt(username),
"username": username
}).__dict__
else:
return Response().error("用户名或密码错误").__dict__
async def reset_password(self):
async def edit_account(self):
password = self.config.dashboard.password
post_data = await request.json
if post_data["password"] == password:
self.config.dashboard.password = post_data['new_password']
return Response().ok(None).__dict__
else:
return Response().error("原密码错误").__dict__
if post_data["password"] != password:
return Response().error("原密码错误").__dict__
new_pwd = post_data.get('new_password', None)
new_username = post_data.get('new_username', None)
if not new_pwd and not new_username:
return Response().error("新用户名和新密码不能同时为空,你改了个寂寞").__dict__
if new_pwd:
self.config.dashboard.password = new_pwd
if new_username:
self.config.dashboard.username = new_username
self.config.flush_config()
return Response().ok(None, "修改成功").__dict__
def generate_jwt(self, username):
payload = {
"username": username,
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=30)
}
token = jwt.encode(payload, WEBUI_SK, algorithm="HS256")
return token
+3 -3
View File
@@ -1,5 +1,5 @@
import os, json
from .route import Route, Response
from .route import Route, Response, RouteContext
from quart import Quart, request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -87,8 +87,8 @@ def save_extension_config(post_config: dict):
update_config(namespace, key, value)
class ConfigRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(config, app)
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
self.config_key_dont_show = ['dashboard', 'config_version']
self.core_lifecycle = core_lifecycle
self.routes = {
+3 -5
View File
@@ -1,13 +1,11 @@
import asyncio
from quart import websocket
from quart import Quart
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger, LogBroker
from .route import Route, Response
from .route import Route, Response, RouteContext
class LogRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, log_broker: LogBroker) -> None:
super().__init__(config, app)
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
super().__init__(context)
self.log_broker = log_broker
self.app.add_url_rule('/api/live-log', view_func=self.log, methods=['GET'], websocket=True)
+3 -4
View File
@@ -1,14 +1,13 @@
import threading, traceback, uuid
from .route import Route, Response
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import Quart, request
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.plugin.plugin_manager import PluginManager
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class PluginRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:
super().__init__(config, app)
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:
super().__init__(context)
self.routes = {
'/plugin/get': ('GET', self.get_plugins),
'/plugin/install': ('POST', self.install_plugin),
+7 -3
View File
@@ -2,11 +2,15 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from dataclasses import dataclass
from quart import Quart
@dataclass
class RouteContext:
config: AstrBotConfig
app: Quart
class Route():
def __init__(self, config: AstrBotConfig, app: Quart):
self.app = app
self.config = config
def __init__(self, context: RouteContext):
self.app = context.app
self.config = context.config
def register_routes(self):
for route, (method, func) in self.routes.items():
+3 -3
View File
@@ -1,5 +1,5 @@
import traceback, psutil, time, aiohttp
from .route import Route, Response
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import Quart, request
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -8,8 +8,8 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.config import VERSION
class StatRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(config, app)
def __init__(self, context: RouteContext, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
self.routes = {
'/stat/get': ('GET', self.get_stat),
'/stat/version': ('GET', self.get_version),
+4 -7
View File
@@ -1,12 +1,9 @@
from .route import Route
from quart import Quart
from astrbot.core.config.astrbot_config import AstrBotConfig
from .route import Route, RouteContext
class StaticFileRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart) -> None:
super().__init__(config, app)
def __init__(self, context: RouteContext) -> None:
super().__init__(context)
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default']
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console']
for i in index_:
self.app.add_url_rule(i, view_func=self.index)
+3 -4
View File
@@ -1,13 +1,12 @@
import threading, traceback
from .route import Route, Response
from .route import Route, Response, RouteContext
from quart import Quart, request
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
class UpdateRoute(Route):
def __init__(self, config: AstrBotConfig, app: Quart, astrbot_updator: AstrBotUpdator) -> None:
super().__init__(config, app)
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
super().__init__(context)
self.routes = {
'/update/check': ('GET', self.check_update),
'/update/do': ('POST', self.update_project),
+40 -17
View File
@@ -1,36 +1,56 @@
import logging
import logging, jwt
import asyncio, os
from quart import Quart
from quart import Quart, request
from quart.logging import default_handler
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import *
from astrbot.core import logger
from .routes.route import RouteContext, Response
from astrbot.core import logger, WEBUI_SK
from astrbot.core.db import BaseDatabase
from astrbot.core.plugin.plugin_manager import PluginManager
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.io import get_local_ip_addresses
from astrbot.core.config import AstrBotConfig
from astrbot.core.db import BaseDatabase
DATAPATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data"))
class AstrBotDashboard():
def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None:
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.data_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/dist"))
self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist"))
logger.info(f"Dashboard data path: {self.data_path}")
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
self.app.json.sort_keys = False
self.app.before_request(self.auth_middleware)
# token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler)
self.context = RouteContext(self.config, self.app)
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator)
self.sr = StatRoute(self.context, db, core_lifecycle)
self.pr = PluginRoute(self.context, core_lifecycle, core_lifecycle.plugin_manager)
self.cr = ConfigRoute(self.context, core_lifecycle)
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
async def auth_middleware(self):
if not request.path.startswith("/api"):
return
if request.path == "/api/auth/login":
return
# claim jwt
token = request.headers.get("Authorization")
if token.startswith("Bearer "):
token = token[7:]
if not token:
return Response().error("未授权").__dict__
try:
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
return Response().error("Token 过期").__dict__
except jwt.InvalidTokenError:
return Response().error("Token 无效").__dict__
self.ar = AuthRoute(self.config, self.app)
self.ur = UpdateRoute(self.config, self.app, core_lifecycle.astrbot_updator)
self.sr = StatRoute(self.config, self.app, db, core_lifecycle)
self.pr = PluginRoute(self.config, self.app, core_lifecycle, core_lifecycle.plugin_manager)
self.cr = ConfigRoute(self.config, self.app, core_lifecycle)
self.lr = LogRoute(self.config, self.app, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.config, self.app)
async def shutdown_trigger_placeholder(self):
while not self.core_lifecycle.event_queue.closed:
await asyncio.sleep(1)
@@ -38,5 +58,8 @@ class AstrBotDashboard():
def run(self):
ip_addr = get_local_ip_addresses()
logger.info(f"\n-----\n🌈 管理面板已启动,可访问 \n1. http://{ip_addr}:6185\n2. http://localhost:6185 登录。\n------")
logger.info(f"""🌈 管理面板已启动,可访问
1. http://{ip_addr}:6185
2. http://localhost:6185
登录。默认用户名和密码是 astrbot。""")
return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder)
+4 -4
View File
@@ -148,11 +148,11 @@ class Main:
if self.provider_config.prompt_prefix:
event.message_str = self.provider_config.prompt_prefix + event.message_str
image_url = None
image_urls = []
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
break
image_urls.append(image_url)
tool_use_flag = False
llm_result = None
@@ -210,7 +210,7 @@ class Main:
llm_result = await self.provider.text_chat(
prompt=event.message_str,
session_id=event.session_id,
image_url=image_url
image_urls=image_urls
)
await Metric.upload(llm_tick=1, llm_name=self.provider.get_model(), llm_api_base=self.provider.base_url)
except BadRequestError as e:
@@ -231,7 +231,7 @@ class Main:
llm_result = await self.provider.text_chat(
prompt=event.message_str,
session_id=event.session_id,
image_url=image_url
image_urls=image_urls
)
except BaseException as e:
logger.error(traceback.format_exc())
@@ -50,9 +50,6 @@ class ProviderOpenAIOfficial(Provider):
'''
将图片转换为 base64
'''
if image_url.startswith("http"):
image_url = await download_image_by_url(image_url)
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
@@ -98,12 +95,14 @@ class ProviderOpenAIOfficial(Provider):
'''
组装上下文。
'''
if image_urls:
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
for image_url in image_urls:
base_64_image = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": base_64_image}})
if image_url.startswith("http"):
image_data = image_url
else:
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
else:
return {"role": "user","content": text}
+1
View File
@@ -13,3 +13,4 @@ psutil
lxml_html_clean
colorlog
aiocqhttp
pyjwt