diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 77a7dbeea..a5851adf5 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -44,7 +44,6 @@ class Stage(abc.ABC): try: ready_to_call = handler(event, **params) except TypeError as e: - print(e) # 向下兼容 ready_to_call = handler(event, ctx.plugin_manager.context, **params) diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 14e48cab4..3b0434518 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -58,7 +58,7 @@ class ProviderZhipu(ProviderOpenAIOfficial): payloads = { "messages": context_query, **model_cfgs - } + } llm_response = None try: llm_response = await self._query(payloads, func_tool) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index b4f28cc07..a6fe88384 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -140,6 +140,11 @@ class PluginManager: def reload(self): '''扫描并加载所有的 Star''' + for smd in star_registry: + logger.debug(f"尝试终止插件 {smd.name} ...") + if hasattr(smd.star_cls, "__del__"): + smd.star_cls.__del__() + star_handlers_registry.clear() star_handlers_registry.star_handlers_map.clear() star_map.clear() @@ -272,6 +277,7 @@ class PluginManager: raise Exception("该插件是 AstrBot 保留插件,无法更新。") await self.updator.update(plugin) + self.reload() def install_plugin_from_file(self, zip_file_path: str): desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path)) diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 93c7aefbd..02b9dc2da 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator): files = os.listdir(os.path.join(target_dir, update_dir)) for f in files: - logger.info(f"移动更新文件/目录: {f}") if os.path.isdir(os.path.join(target_dir, update_dir, f)): if os.path.exists(os.path.join(target_dir, f)): shutil.rmtree(os.path.join(target_dir, f), onerror=on_error) @@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator): shutil.move(os.path.join(target_dir, update_dir, f), target_dir) try: - logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") + logger.info(f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}") shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error) os.remove(zip_path) except BaseException: diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index ed3657531..ade94fa6e 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -111,7 +111,7 @@ class RepoZipUpdator(): releases = await self.fetch_release_info(url=release_url) if not releases: # download from the default branch directly. - logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。") + logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。") release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip" else: release_url = releases[0]['zipball_url'] diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index b18afa3b4..c43ab7660 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -53,7 +53,6 @@ class PluginRoute(Route): try: logger.info(f"正在安装插件 {repo_url}") await self.plugin_manager.install_plugin(repo_url) - self.core_lifecycle.restart() logger.info(f"安装插件 {repo_url} 成功。") return Response().ok(None, "安装成功。").__dict__ except Exception as e: @@ -69,7 +68,6 @@ class PluginRoute(Route): await file.save(file_path) self.plugin_manager.install_plugin_from_file(file_path) logger.info(f"安装插件 {file.filename} 成功") - self.core_lifecycle.restart() return Response().ok(None, "安装成功。").__dict__ except Exception as e: logger.error(traceback.format_exc()) @@ -93,9 +91,8 @@ class PluginRoute(Route): try: logger.info(f"正在更新插件 {plugin_name}") await self.plugin_manager.update_plugin(plugin_name) - self.core_lifecycle.restart() - logger.info(f"更新插件 {plugin_name} 成功,2秒后重启") - return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__ + logger.info(f"更新插件 {plugin_name} 成功。") + return Response().ok(None, "更新成功。").__dict__ except Exception as e: logger.error(f"/api/extensions/update: {traceback.format_exc()}") return Response().error(str(e)).__dict__ \ No newline at end of file diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py index 8fc6fcd53..03f241de4 100644 --- a/astrbot/dashboard/routes/update.py +++ b/astrbot/dashboard/routes/update.py @@ -32,6 +32,7 @@ class UpdateRoute(Route): async def update_project(self): data = await request.json version = data.get('version', '') + reboot = data.get('reboot', True) if version == "" or version == "latest": latest = True version = '' @@ -39,8 +40,11 @@ class UpdateRoute(Route): latest = False try: await self.astrbot_updator.update(latest=latest, version=version) - threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start() - return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__ + if reboot: + threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start() + return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__ + else: + return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__ except Exception as e: logger.error(f"/api/update_project: {traceback.format_exc()}") return Response().error(e.__str__()).__dict__ \ No newline at end of file diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py new file mode 100644 index 000000000..970ee2749 --- /dev/null +++ b/tests/test_dashboard.py @@ -0,0 +1,149 @@ +import pytest +import os +from quart import Quart +from astrbot.dashboard.server import AstrBotDashboard +from astrbot.core.db.sqlite import SQLiteDatabase +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core import LogBroker +from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.star.star import star_registry +from astrbot.core.updator import AstrBotUpdator + + +@pytest.fixture(scope="module") +def core_lifecycle(): + db = SQLiteDatabase("data/data_v3.db") + log_broker = LogBroker() + core_lifecycle = AstrBotCoreLifecycle(log_broker, db) + return core_lifecycle + +@pytest.fixture(scope="module") +def app(core_lifecycle): + db = SQLiteDatabase("data/data_v3.db") + server = AstrBotDashboard(core_lifecycle, db) + return server.app + +@pytest.fixture(scope="module") +def header(): + return {} + +@pytest.mark.asyncio +async def test_init_core_lifecycle(core_lifecycle): + await core_lifecycle.initialize() + assert core_lifecycle is not None + +@pytest.mark.asyncio +async def test_auth_login(app: Quart, core_lifecycle: AstrBotCoreLifecycle, header: dict): + test_client = app.test_client() + response = await test_client.post('/api/auth/login', json={ + "username": "wrong", + "password": "password" + }) + data = await response.get_json() + assert data['status'] == 'error' + + response = await test_client.post('/api/auth/login', json={ + "username": core_lifecycle.astrbot_config['dashboard']['username'], + "password": core_lifecycle.astrbot_config['dashboard']['password'] + }) + data = await response.get_json() + assert data['status'] == 'ok' and 'token' in data['data'] + header['Authorization'] = f"Bearer {data['data']['token']}" + +@pytest.mark.asyncio +async def test_get_stat(app: Quart, header: dict): + test_client = app.test_client() + response = await test_client.get('/api/stat/get') + assert response.status_code == 401 + response = await test_client.get('/api/stat/get', headers=header) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'ok' and 'platform' in data['data'] + +@pytest.mark.asyncio +async def test_plugins(app: Quart, header: dict): + test_client = app.test_client() + # 已经安装的插件 + response = await test_client.get('/api/plugin/get', headers=header) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'ok' + + # 插件市场 + response = await test_client.get('/api/plugin/market_list', headers=header) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'ok' + + # 插件安装 + response = await test_client.post('/api/plugin/install', json={ + "url": "https://github.com/Soulter/astrbot_plugin_essential" + }, headers=header) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'ok' + exists = False + for md in star_registry: + if md.name == "astrbot_plugin_essential": + exists = True + break + assert exists is True, "插件 astrbot_plugin_essential 未成功载入" + + # 插件更新 + response = await test_client.post('/api/plugin/update', json={ + "name": "astrbot_plugin_essential" + }, headers=header) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'ok' + + # 插件卸载 + response = await test_client.post('/api/plugin/uninstall', json={ + "name": "astrbot_plugin_essential" + }, headers=header) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'ok' + exists = False + for md in star_registry: + if md.name == "astrbot_plugin_essential": + exists = True + break + assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" + exists = False + for md in star_handlers_registry: + if "astrbot_plugin_essential" in md.handler_module_path: + exists = True + break + assert exists is False, "插件 astrbot_plugin_essential 未成功卸载" + +@pytest.mark.asyncio +async def test_check_update(app: Quart, header: dict): + test_client = app.test_client() + response = await test_client.get('/api/update/check', headers=header) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'success' + +@pytest.mark.asyncio +async def test_do_update(app: Quart, header: dict, core_lifecycle: AstrBotCoreLifecycle): + global VERSION + test_client = app.test_client() + os.makedirs("data/astrbot_release", exist_ok=True) + core_lifecycle.astrbot_updator.MAIN_PATH = "data/astrbot_release" + VERSION = "114.514.1919810" + response = await test_client.post('/api/update/do', headers=header, json={ + "version": "latest" + }) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'error' # 已经是最新版本 + + response = await test_client.post('/api/update/do', headers=header, json={ + "version": "v3.4.0", + "reboot": False + }) + assert response.status_code == 200 + data = await response.get_json() + assert data['status'] == 'ok' + assert os.path.exists("data/astrbot_release/astrbot") \ No newline at end of file diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a142d79ad..12decef50 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,4 +1,6 @@ -import pytest, logging, os +import pytest +import logging +import os from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext from astrbot.core.star import PluginManager from astrbot.core.config.astrbot_config import AstrBotConfig @@ -11,7 +13,6 @@ from astrbot.core.platform.manager import PlatformManager from astrbot.core.provider.manager import ProviderManager from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.star.context import Context -from astrbot.core import logger from asyncio import Queue SESSION_ID_IN_WHITELIST = "test_sid_wl" @@ -29,7 +30,7 @@ TEST_LLM_PROVIDER = { TEST_COMMANDS = [ ["help", "已注册的 AstrBot 内置指令"], - ["tool ls", "查看、激活、停用当前注册的函数工具"], + ["tool ls", "函数工具"], ["tool on websearch", "激活工具"], ["tool off websearch", "停用工具"], ["plugin", "已加载的插件"], @@ -145,6 +146,7 @@ async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineSch async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog): '''测试唤醒''' # 群聊无 @ 无指令 + caplog.clear() mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) @@ -161,19 +163,21 @@ async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog): @pytest.mark.asyncio async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog): + caplog.clear() + mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123") + with caplog.at_level(logging.INFO): + await pipeline_scheduler.execute(mock_event) + assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息" + mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123") with caplog.at_level(logging.INFO): await pipeline_scheduler.execute(mock_event) assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息" - mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123") - await pipeline_scheduler.execute(mock_event) - assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息" - - @pytest.mark.asyncio async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog): # 测试默认屏蔽词 + caplog.clear() mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。 with caplog.at_level(logging.INFO): await pipeline_scheduler.execute(mock_event) @@ -192,6 +196,7 @@ async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, ca @pytest.mark.asyncio async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog): + caplog.clear() mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) @@ -201,6 +206,7 @@ async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog): @pytest.mark.asyncio async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog): + caplog.clear() mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) @@ -210,6 +216,7 @@ async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog) @pytest.mark.asyncio async def test_commands(pipeline_scheduler: PipelineScheduler, caplog): for command in TEST_COMMANDS: + caplog.clear() mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST) with caplog.at_level(logging.DEBUG): await pipeline_scheduler.execute(mock_event) diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 8056a1846..a77050070 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,6 +1,5 @@ import pytest import os -import shutil from astrbot.core.star.star_manager import PluginManager from astrbot.core.star.star_handler import star_handlers_registry from astrbot.core.star.star import star_registry @@ -34,7 +33,8 @@ def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): assert plugin_manager_pm.context is not None assert plugin_manager_pm.config is not None -def test_plugin_manager_reload(plugin_manager_pm: PluginManager): +@pytest.mark.asyncio +async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): success, err_message = plugin_manager_pm.reload() assert success is True assert err_message is None @@ -86,8 +86,3 @@ async def test_plugin_crud(plugin_manager_pm: PluginManager): await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha") # TODO: file installation - - - - - \ No newline at end of file