diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index a1c61bb7c..bad5b37c8 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -5,4 +5,5 @@ from astrbot.core.config.default import DB_PATH html_renderer = HtmlRenderer() logger = LogManager.GetLogger(log_name='astrbot') -db_helper = SQLiteDatabase(DB_PATH) \ No newline at end of file +db_helper = SQLiteDatabase(DB_PATH) +WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" \ No newline at end of file diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 3e8b572dd..85e3a6252 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -142,6 +142,7 @@ class ProjectATRI: long_term_memory: ATRILongTermMemory = field(default_factory=ATRILongTermMemory) active_message: ATRIActiveMessage = field(default_factory=ATRIActiveMessage) persona: str = "" + split_response: bool = True embedding_provider_id: str = "" summarize_provider_id: str = "" chat_provider_id: str = "" diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index aeb7312f3..c47df8d9e 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -3,7 +3,7 @@ ''' VERSION = '3.4.0' -DB_PATH = 'data/data_v2.db' +DB_PATH = 'data/data_v3.db' # LLM 提供商配置模板 PROVIDER_CONFIG_TEMPLATE = { @@ -144,8 +144,8 @@ DEFAULT_CONFIG_VERSION_2 = { "http_proxy": "", "dashboard": { "enable": True, - "username": "", - "password": "", + "username": "astrbot", + "password": "77b90590a8945a7d36c963981a307dc9", }, "log_level": "INFO", "t2i_endpoint": "", @@ -160,7 +160,12 @@ DEFAULT_CONFIG_VERSION_2 = { "active_message": { "enable": False, }, + "vision": { + "enable": False, + "provider_id_or_ofa_model_path": "", + }, "persona": "", + "split_response": True, "embedding_provider_id": "", "summarize_provider_id": "", "chat_provider_id": "", @@ -188,7 +193,7 @@ CONFIG_METADATA_2 = { "ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"}, "qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"}, "qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"}, - "wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"}, + "wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "obvious_hint": True, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。注意:每次扫码登录之后,相同联系人的 ID 会发生变化,白名单内的 ID 会失效。"}, } }, "platform_settings": { @@ -322,6 +327,15 @@ CONFIG_METADATA_2 = { "enable": {"description": "启用", "type": "bool"}, } }, + "vision": { + "description": "视觉理解", + "type": "object", + "items": { + "enable": {"description": "启用", "type": "bool"}, + "provider_id_or_ofa_model_path": {"description": "提供商 ID 或 OFA 模型路径", "type": "string", "hint": "将会使用指定的 provider 来进行视觉处理,请确保所填的 provider id 在 `配置页` 中存在。"}, + } + }, + "split_response": {"description": "是否分割回复", "type": "bool", "hint": "启用后,将会根据句子分割回复以更像人类回复。每次回复之间具有随机的事件间隔。默认启用。"}, "persona": {"description": "人格", "type": "string", "hint": "默认人格。当启动 ATRI 之后,在 Provider 处设置的人格将会失效。", "obvious_hint": True}, "embedding_provider_id": {"description": "Embedding provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Embedding,请确保所填的 provider id 在 `配置页` 中存在并且设置了 Embedding 配置", "obvious_hint": True}, "summarize_provider_id": {"description": "Summary provider ID", "type": "string", "hint": "只有当启用了长期记忆时,才需要填写此项。将会使用指定的 provider 来获取 Summary,请确保所填的 provider id 在 `配置页` 中存在。", "obvious_hint": True}, diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index b8e8afb8a..b34e14c1a 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -417,6 +417,7 @@ class Unknown(BaseMessageComponent): ComponentTypes = { "plain": Plain, + "text": Plain, "face": Face, "record": Record, "video": Video, diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index bf8907176..bd36cae75 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -2,6 +2,7 @@ import os import ssl import shutil import socket +import ifaddr import time import aiohttp import base64 diff --git a/astrbot/dashboard/routes/auth.py b/astrbot/dashboard/routes/auth.py index b76a9eeaa..b54b3aebc 100644 --- a/astrbot/dashboard/routes/auth.py +++ b/astrbot/dashboard/routes/auth.py @@ -1,13 +1,14 @@ -from .route import Route, Response +import jwt, datetime +from .route import Route, Response, RouteContext from quart import Quart, request -from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core import WEBUI_SK class AuthRoute(Route): - def __init__(self, config: AstrBotConfig, app: Quart) -> None: - super().__init__(config, app) + def __init__(self, context: RouteContext) -> None: + super().__init__(context) self.routes = { '/auth/login': ('POST', self.login), - '/auth/password/reset': ('POST', self.reset_password), + '/auth/account/edit': ('POST', self.edit_account), } self.register_routes() @@ -17,17 +18,37 @@ class AuthRoute(Route): post_data = await request.json if post_data["username"] == username and post_data["password"] == password: return Response().ok({ - "token": "astrbot-test-token", + "token": self.generate_jwt(username), "username": username }).__dict__ else: return Response().error("用户名或密码错误").__dict__ - async def reset_password(self): + async def edit_account(self): password = self.config.dashboard.password post_data = await request.json - if post_data["password"] == password: - self.config.dashboard.password = post_data['new_password'] - return Response().ok(None).__dict__ - else: - return Response().error("原密码错误").__dict__ \ No newline at end of file + + if post_data["password"] != password: + return Response().error("原密码错误").__dict__ + + new_pwd = post_data.get('new_password', None) + new_username = post_data.get('new_username', None) + if not new_pwd and not new_username: + return Response().error("新用户名和新密码不能同时为空,你改了个寂寞").__dict__ + + if new_pwd: + self.config.dashboard.password = new_pwd + if new_username: + self.config.dashboard.username = new_username + + self.config.flush_config() + + return Response().ok(None, "修改成功").__dict__ + + def generate_jwt(self, username): + payload = { + "username": username, + "exp": datetime.datetime.utcnow() + datetime.timedelta(days=30) + } + token = jwt.encode(payload, WEBUI_SK, algorithm="HS256") + return token \ No newline at end of file diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 55f3379c1..2a0508a27 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,5 +1,5 @@ import os, json -from .route import Route, Response +from .route import Route, Response, RouteContext from quart import Quart, request from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP, PROVIDER_CONFIG_TEMPLATE from astrbot.core.config.astrbot_config import AstrBotConfig @@ -87,8 +87,8 @@ def save_extension_config(post_config: dict): update_config(namespace, key, value) class ConfigRoute(Route): - def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle) -> None: - super().__init__(config, app) + def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None: + super().__init__(context) self.config_key_dont_show = ['dashboard', 'config_version'] self.core_lifecycle = core_lifecycle self.routes = { diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index 8dabccf6f..f05004ccc 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -1,13 +1,11 @@ import asyncio from quart import websocket -from quart import Quart -from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core import logger, LogBroker -from .route import Route, Response +from .route import Route, Response, RouteContext class LogRoute(Route): - def __init__(self, config: AstrBotConfig, app: Quart, log_broker: LogBroker) -> None: - super().__init__(config, app) + def __init__(self, context: RouteContext, log_broker: LogBroker) -> None: + super().__init__(context) self.log_broker = log_broker self.app.add_url_rule('/api/live-log', view_func=self.log, methods=['GET'], websocket=True) diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index 1a112ff2a..3977d906b 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,14 +1,13 @@ import threading, traceback, uuid -from .route import Route, Response +from .route import Route, Response, RouteContext from astrbot.core import logger from quart import Quart, request -from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.plugin.plugin_manager import PluginManager from astrbot.core.core_lifecycle import AstrBotCoreLifecycle class PluginRoute(Route): - def __init__(self, config: AstrBotConfig, app: Quart, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None: - super().__init__(config, app) + def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle, plugin_manager: PluginManager) -> None: + super().__init__(context) self.routes = { '/plugin/get': ('GET', self.get_plugins), '/plugin/install': ('POST', self.install_plugin), diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 5b2ff41c2..cbec57cdf 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -2,11 +2,15 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from dataclasses import dataclass from quart import Quart +@dataclass +class RouteContext: + config: AstrBotConfig + app: Quart class Route(): - def __init__(self, config: AstrBotConfig, app: Quart): - self.app = app - self.config = config + def __init__(self, context: RouteContext): + self.app = context.app + self.config = context.config def register_routes(self): for route, (method, func) in self.routes.items(): diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index d69e474d1..d0b4d62d5 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -1,5 +1,5 @@ import traceback, psutil, time, aiohttp -from .route import Route, Response +from .route import Route, Response, RouteContext from astrbot.core import logger from quart import Quart, request from astrbot.core.config.astrbot_config import AstrBotConfig @@ -8,8 +8,8 @@ from astrbot.core.db import BaseDatabase from astrbot.core.config import VERSION class StatRoute(Route): - def __init__(self, config: AstrBotConfig, app: Quart, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None: - super().__init__(config, app) + def __init__(self, context: RouteContext, db_helper: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None: + super().__init__(context) self.routes = { '/stat/get': ('GET', self.get_stat), '/stat/version': ('GET', self.get_version), diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 4e8b835c2..3d12cdedd 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -1,12 +1,9 @@ -from .route import Route -from quart import Quart -from astrbot.core.config.astrbot_config import AstrBotConfig - +from .route import Route, RouteContext class StaticFileRoute(Route): - def __init__(self, config: AstrBotConfig, app: Quart) -> None: - super().__init__(config, app) + def __init__(self, context: RouteContext) -> None: + super().__init__(context) - index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default'] + index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console'] for i in index_: self.app.add_url_rule(i, view_func=self.index) diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index a4566b190..52300a36f 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -1,13 +1,12 @@ import threading, traceback -from .route import Route, Response +from .route import Route, Response, RouteContext from quart import Quart, request -from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.updator import AstrBotUpdator from astrbot.core import logger class UpdateRoute(Route): - def __init__(self, config: AstrBotConfig, app: Quart, astrbot_updator: AstrBotUpdator) -> None: - super().__init__(config, app) + def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None: + super().__init__(context) self.routes = { '/update/check': ('GET', self.check_update), '/update/do': ('POST', self.update_project), diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 32d5713f6..9f5c772ec 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -1,36 +1,56 @@ -import logging +import logging, jwt import asyncio, os -from quart import Quart +from quart import Quart, request from quart.logging import default_handler from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from .routes import * -from astrbot.core import logger +from .routes.route import RouteContext, Response +from astrbot.core import logger, WEBUI_SK from astrbot.core.db import BaseDatabase -from astrbot.core.plugin.plugin_manager import PluginManager -from astrbot.core.updator import AstrBotUpdator from astrbot.core.utils.io import get_local_ip_addresses -from astrbot.core.config import AstrBotConfig from astrbot.core.db import BaseDatabase +DATAPATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data")) + class AstrBotDashboard(): def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None: self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config - self.data_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../data/dist")) + self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist")) logger.info(f"Dashboard data path: {self.data_path}") self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/") self.app.json.sort_keys = False - + self.app.before_request(self.auth_middleware) + # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) + self.context = RouteContext(self.config, self.app) + self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator) + self.sr = StatRoute(self.context, db, core_lifecycle) + self.pr = PluginRoute(self.context, core_lifecycle, core_lifecycle.plugin_manager) + self.cr = ConfigRoute(self.context, core_lifecycle) + self.lr = LogRoute(self.context, core_lifecycle.log_broker) + self.sfr = StaticFileRoute(self.context) + self.ar = AuthRoute(self.context) + + async def auth_middleware(self): + if not request.path.startswith("/api"): + return + if request.path == "/api/auth/login": + return + # claim jwt + token = request.headers.get("Authorization") + if token.startswith("Bearer "): + token = token[7:] + if not token: + return Response().error("未授权").__dict__ + try: + jwt.decode(token, WEBUI_SK, algorithms=["HS256"]) + except jwt.ExpiredSignatureError: + return Response().error("Token 过期").__dict__ + except jwt.InvalidTokenError: + return Response().error("Token 无效").__dict__ + - self.ar = AuthRoute(self.config, self.app) - self.ur = UpdateRoute(self.config, self.app, core_lifecycle.astrbot_updator) - self.sr = StatRoute(self.config, self.app, db, core_lifecycle) - self.pr = PluginRoute(self.config, self.app, core_lifecycle, core_lifecycle.plugin_manager) - self.cr = ConfigRoute(self.config, self.app, core_lifecycle) - self.lr = LogRoute(self.config, self.app, core_lifecycle.log_broker) - self.sfr = StaticFileRoute(self.config, self.app) - async def shutdown_trigger_placeholder(self): while not self.core_lifecycle.event_queue.closed: await asyncio.sleep(1) @@ -38,5 +58,8 @@ class AstrBotDashboard(): def run(self): ip_addr = get_local_ip_addresses() - logger.info(f"\n-----\n🌈 管理面板已启动,可访问 \n1. http://{ip_addr}:6185\n2. http://localhost:6185 登录。\n------") + logger.info(f"""🌈 管理面板已启动,可访问 +1. http://{ip_addr}:6185 +2. http://localhost:6185 +登录。默认用户名和密码是 astrbot。""") return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder) \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/main.py b/packages/astrbot_plugin_openai/main.py index dd1445a88..990b534da 100644 --- a/packages/astrbot_plugin_openai/main.py +++ b/packages/astrbot_plugin_openai/main.py @@ -148,11 +148,11 @@ class Main: if self.provider_config.prompt_prefix: event.message_str = self.provider_config.prompt_prefix + event.message_str - image_url = None + image_urls = [] for comp in event.message_obj.message: if isinstance(comp, Image): image_url = comp.url if comp.url else comp.file - break + image_urls.append(image_url) tool_use_flag = False llm_result = None @@ -210,7 +210,7 @@ class Main: llm_result = await self.provider.text_chat( prompt=event.message_str, session_id=event.session_id, - image_url=image_url + image_urls=image_urls ) await Metric.upload(llm_tick=1, llm_name=self.provider.get_model(), llm_api_base=self.provider.base_url) except BadRequestError as e: @@ -231,7 +231,7 @@ class Main: llm_result = await self.provider.text_chat( prompt=event.message_str, session_id=event.session_id, - image_url=image_url + image_urls=image_urls ) except BaseException as e: logger.error(traceback.format_exc()) diff --git a/packages/astrbot_plugin_openai/openai_adapter.py b/packages/astrbot_plugin_openai/openai_adapter.py index a1da2349c..7fe8cfa8d 100644 --- a/packages/astrbot_plugin_openai/openai_adapter.py +++ b/packages/astrbot_plugin_openai/openai_adapter.py @@ -50,9 +50,6 @@ class ProviderOpenAIOfficial(Provider): ''' 将图片转换为 base64 ''' - if image_url.startswith("http"): - image_url = await download_image_by_url(image_url) - with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode('utf-8') return "data:image/jpeg;base64," + image_bs64 @@ -98,12 +95,14 @@ class ProviderOpenAIOfficial(Provider): ''' 组装上下文。 ''' - if image_urls: user_content = {"role": "user","content": [{"type": "text", "text": text}]} for image_url in image_urls: - base_64_image = await self.encode_image_bs64(image_url) - user_content["content"].append({"type": "image_url", "image_url": {"url": base_64_image}}) + if image_url.startswith("http"): + image_data = image_url + else: + image_data = await self.encode_image_bs64(image_url) + user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) return user_content else: return {"role": "user","content": text} diff --git a/requirements.txt b/requirements.txt index e301c3904..1d5bfe9ae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ psutil lxml_html_clean colorlog aiocqhttp +pyjwt \ No newline at end of file