feat: 使用 jwt 用于管理面板鉴权
This commit is contained in:
@@ -5,4 +5,5 @@ from astrbot.core.config.default import DB_PATH
|
||||
|
||||
html_renderer = HtmlRenderer()
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
@@ -142,6 +142,7 @@ class ProjectATRI:
|
||||
long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory)
|
||||
active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage)
|
||||
persona: str = ""
|
||||
split_response: bool = True
|
||||
embedding_provider_id: str = ""
|
||||
summarize_provider_id: str = ""
|
||||
chat_provider_id: str = ""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
'''
|
||||
|
||||
VERSION = '3.4.0'
|
||||
DB_PATH = 'data/data_v2.db'
|
||||
DB_PATH = 'data/data_v3.db'
|
||||
|
||||
# LLM 提供商配置模板
|
||||
PROVIDER_CONFIG_TEMPLATE = {
|
||||
@@ -144,8 +144,8 @@ DEFAULT_CONFIG_VERSION_2 = {
|
||||
"http_proxy": "",
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "",
|
||||
"password": "",
|
||||
"username": "astrbot",
|
||||
"password": "77b90590a8945a7d36c963981a307dc9",
|
||||
},
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
@@ -160,7 +160,12 @@ DEFAULT_CONFIG_VERSION_2 = {
|
||||
"active_message": {
|
||||
"enable": False,
|
||||
},
|
||||
"vision": {
|
||||
"enable": False,
|
||||
"provider_id_or_ofa_model_path": "",
|
||||
},
|
||||
"persona": "",
|
||||
"split_response": True,
|
||||
"embedding_provider_id": "",
|
||||
"summarize_provider_id": "",
|
||||
"chat_provider_id": "",
|
||||
@@ -188,7 +193,7 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"},
|
||||
"qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"},
|
||||
"qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"},
|
||||
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"},
|
||||
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "obvious_hint": True, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"},
|
||||
}
|
||||
},
|
||||
"platform_settings": {
|
||||
@@ -322,6 +327,15 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
}
|
||||
},
|
||||
"vision": {
|
||||
"description": "视觉理解",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {"description": "启用", "type": "bool"},
|
||||
"provider_id_or_ofa_model_path": {"description": "提供商 ID 或 OFA 模型路径", "type": "string", "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。"},
|
||||
}
|
||||
},
|
||||
"split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的事件间隔。默认启用。"},
|
||||
"persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True},
|
||||
"embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True},
|
||||
"summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True},
|
||||
|
||||
@@ -417,6 +417,7 @@ class Unknown(BaseMessageComponent):
|
||||
|
||||
ComponentTypes = {
|
||||
"plain": Plain,
|
||||
"text": Plain,
|
||||
"face": Face,
|
||||
"record": Record,
|
||||
"video": Video,
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import ssl
|
||||
import shutil
|
||||
import socket
|
||||
import ifaddr
|
||||
import time
|
||||
import aiohttp
|
||||
import base64
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from .route import Route, Response
|
||||
import jwt, datetime
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import Quart, request
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import WEBUI_SK
|
||||
|
||||
class AuthRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart) -> None:
|
||||
super().__init__(config, app)
|
||||
def __init__(self, context: RouteContext) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/auth/login': ('POST', self.login),
|
||||
'/auth/password/reset': ('POST', self.reset_password),
|
||||
'/auth/account/edit': ('POST', self.edit_account),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
@@ -17,17 +18,37 @@ class AuthRoute(Route):
|
||||
post_data = await request.json
|
||||
if post_data["username"] == username and post_data["password"] == password:
|
||||
return Response().ok({
|
||||
"token": "astrbot-test-token",
|
||||
"token": self.generate_jwt(username),
|
||||
"username": username
|
||||
}).__dict__
|
||||
else:
|
||||
return Response().error("用户名或密码错误").__dict__
|
||||
|
||||
async def reset_password(self):
|
||||
async def edit_account(self):
|
||||
password = self.config.dashboard.password
|
||||
post_data = await request.json
|
||||
if post_data["password"] == password:
|
||||
self.config.dashboard.password = post_data['new_password']
|
||||
return Response().ok(None).__dict__
|
||||
else:
|
||||
return Response().error("原密码错误").__dict__
|
||||
|
||||
if post_data["password"] != password:
|
||||
return Response().error("原密码错误").__dict__
|
||||
|
||||
new_pwd = post_data.get('new_password', None)
|
||||
new_username = post_data.get('new_username', None)
|
||||
if not new_pwd and not new_username:
|
||||
return Response().error("新用户名和新密码不能同时为空,你改了个寂寞").__dict__
|
||||
|
||||
if new_pwd:
|
||||
self.config.dashboard.password = new_pwd
|
||||
if new_username:
|
||||
self.config.dashboard.username = new_username
|
||||
|
||||
self.config.flush_config()
|
||||
|
||||
return Response().ok(None, "修改成功").__dict__
|
||||
|
||||
def generate_jwt(self, username):
|
||||
payload = {
|
||||
"username": username,
|
||||
"exp": datetime.datetime.utcnow() + datetime.timedelta(days=30)
|
||||
}
|
||||
token = jwt.encode(payload, WEBUI_SK, algorithm="HS256")
|
||||
return token
|
||||
@@ -1,5 +1,5 @@
|
||||
import os, json
|
||||
from .route import Route, Response
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import Quart, request
|
||||
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -87,8 +87,8 @@ def save_extension_config(post_config: dict):
|
||||
update_config(namespace, key, value)
|
||||
|
||||
class ConfigRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(config, app)
|
||||
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(context)
|
||||
self.config_key_dont_show = ['dashboard', 'config_version']
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.routes = {
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import asyncio
|
||||
from quart import websocket
|
||||
from quart import Quart
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, LogBroker
|
||||
from .route import Route, Response
|
||||
from .route import Route, Response, RouteContext
|
||||
|
||||
class LogRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, log_broker: LogBroker) -> None:
|
||||
super().__init__(config, app)
|
||||
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
|
||||
super().__init__(context)
|
||||
self.log_broker = log_broker
|
||||
self.app.add_url_rule('/api/live-log', view_func=self.log, methods=['GET'], websocket=True)
|
||||
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import threading, traceback, uuid
|
||||
from .route import Route, Response
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import Quart, request
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.plugin.plugin_manager import PluginManager
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
class PluginRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:
|
||||
super().__init__(config, app)
|
||||
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/plugin/get': ('GET', self.get_plugins),
|
||||
'/plugin/install': ('POST', self.install_plugin),
|
||||
|
||||
@@ -2,11 +2,15 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from dataclasses import dataclass
|
||||
from quart import Quart
|
||||
|
||||
@dataclass
|
||||
class RouteContext:
|
||||
config: AstrBotConfig
|
||||
app: Quart
|
||||
|
||||
class Route():
|
||||
def __init__(self, config: AstrBotConfig, app: Quart):
|
||||
self.app = app
|
||||
self.config = config
|
||||
def __init__(self, context: RouteContext):
|
||||
self.app = context.app
|
||||
self.config = context.config
|
||||
|
||||
def register_routes(self):
|
||||
for route, (method, func) in self.routes.items():
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import traceback, psutil, time, aiohttp
|
||||
from .route import Route, Response
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import Quart, request
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -8,8 +8,8 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import VERSION
|
||||
|
||||
class StatRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(config, app)
|
||||
def __init__(self, context: RouteContext, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/stat/get': ('GET', self.get_stat),
|
||||
'/stat/version': ('GET', self.get_version),
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from .route import Route
|
||||
from quart import Quart
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
|
||||
from .route import Route, RouteContext
|
||||
class StaticFileRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart) -> None:
|
||||
super().__init__(config, app)
|
||||
def __init__(self, context: RouteContext) -> None:
|
||||
super().__init__(context)
|
||||
|
||||
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default']
|
||||
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console']
|
||||
for i in index_:
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import threading, traceback
|
||||
from .route import Route, Response
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import Quart, request
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(self, config: AstrBotConfig, app: Quart, astrbot_updator: AstrBotUpdator) -> None:
|
||||
super().__init__(config, app)
|
||||
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/update/check': ('GET', self.check_update),
|
||||
'/update/do': ('POST', self.update_project),
|
||||
|
||||
+40
-17
@@ -1,36 +1,56 @@
|
||||
import logging
|
||||
import logging, jwt
|
||||
import asyncio, os
|
||||
from quart import Quart
|
||||
from quart import Quart, request
|
||||
from quart.logging import default_handler
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
from astrbot.core import logger
|
||||
from .routes.route import RouteContext, Response
|
||||
from astrbot.core import logger, WEBUI_SK
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.plugin.plugin_manager import PluginManager
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
DATAPATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data"))
|
||||
|
||||
class AstrBotDashboard():
|
||||
def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None:
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.data_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/dist"))
|
||||
self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist"))
|
||||
logger.info(f"Dashboard data path: {self.data_path}")
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
self.app.json.sort_keys = False
|
||||
|
||||
self.app.before_request(self.auth_middleware)
|
||||
# token 用于验证请求
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
self.context = RouteContext(self.config, self.app)
|
||||
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator)
|
||||
self.sr = StatRoute(self.context, db, core_lifecycle)
|
||||
self.pr = PluginRoute(self.context, core_lifecycle, core_lifecycle.plugin_manager)
|
||||
self.cr = ConfigRoute(self.context, core_lifecycle)
|
||||
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
||||
self.sfr = StaticFileRoute(self.context)
|
||||
self.ar = AuthRoute(self.context)
|
||||
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return
|
||||
if request.path == "/api/auth/login":
|
||||
return
|
||||
# claim jwt
|
||||
token = request.headers.get("Authorization")
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
if not token:
|
||||
return Response().error("未授权").__dict__
|
||||
try:
|
||||
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError:
|
||||
return Response().error("Token 过期").__dict__
|
||||
except jwt.InvalidTokenError:
|
||||
return Response().error("Token 无效").__dict__
|
||||
|
||||
|
||||
self.ar = AuthRoute(self.config, self.app)
|
||||
self.ur = UpdateRoute(self.config, self.app, core_lifecycle.astrbot_updator)
|
||||
self.sr = StatRoute(self.config, self.app, db, core_lifecycle)
|
||||
self.pr = PluginRoute(self.config, self.app, core_lifecycle, core_lifecycle.plugin_manager)
|
||||
self.cr = ConfigRoute(self.config, self.app, core_lifecycle)
|
||||
self.lr = LogRoute(self.config, self.app, core_lifecycle.log_broker)
|
||||
self.sfr = StaticFileRoute(self.config, self.app)
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while not self.core_lifecycle.event_queue.closed:
|
||||
await asyncio.sleep(1)
|
||||
@@ -38,5 +58,8 @@ class AstrBotDashboard():
|
||||
|
||||
def run(self):
|
||||
ip_addr = get_local_ip_addresses()
|
||||
logger.info(f"\n-----\n🌈 管理面板已启动,可访问 \n1. http://{ip_addr}:6185\n2. http://localhost:6185 登录。\n------")
|
||||
logger.info(f"""🌈 管理面板已启动,可访问
|
||||
1. http://{ip_addr}:6185
|
||||
2. http://localhost:6185
|
||||
登录。默认用户名和密码是 astrbot。""")
|
||||
return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder)
|
||||
@@ -148,11 +148,11 @@ class Main:
|
||||
if self.provider_config.prompt_prefix:
|
||||
event.message_str = self.provider_config.prompt_prefix + event.message_str
|
||||
|
||||
image_url = None
|
||||
image_urls = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
break
|
||||
image_urls.append(image_url)
|
||||
|
||||
tool_use_flag = False
|
||||
llm_result = None
|
||||
@@ -210,7 +210,7 @@ class Main:
|
||||
llm_result = await self.provider.text_chat(
|
||||
prompt=event.message_str,
|
||||
session_id=event.session_id,
|
||||
image_url=image_url
|
||||
image_urls=image_urls
|
||||
)
|
||||
await Metric.upload(llm_tick=1, llm_name=self.provider.get_model(), llm_api_base=self.provider.base_url)
|
||||
except BadRequestError as e:
|
||||
@@ -231,7 +231,7 @@ class Main:
|
||||
llm_result = await self.provider.text_chat(
|
||||
prompt=event.message_str,
|
||||
session_id=event.session_id,
|
||||
image_url=image_url
|
||||
image_urls=image_urls
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -50,9 +50,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
将图片转换为 base64
|
||||
'''
|
||||
if image_url.startswith("http"):
|
||||
image_url = await download_image_by_url(image_url)
|
||||
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
@@ -98,12 +95,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
'''
|
||||
组装上下文。
|
||||
'''
|
||||
|
||||
if image_urls:
|
||||
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
|
||||
for image_url in image_urls:
|
||||
base_64_image = await self.encode_image_bs64(image_url)
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": base_64_image}})
|
||||
if image_url.startswith("http"):
|
||||
image_data = image_url
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user","content": text}
|
||||
|
||||
@@ -13,3 +13,4 @@ psutil
|
||||
lxml_html_clean
|
||||
colorlog
|
||||
aiocqhttp
|
||||
pyjwt
|
||||
Reference in New Issue
Block a user