Merge branch 'master' into feat-dify

This commit is contained in:
Soulter
2025-01-07 19:59:54 +08:00
32 changed files with 915 additions and 56 deletions
+15 -9
View File
@@ -1,7 +1,14 @@
name: Run tests and upload coverage
on:
push
push:
branches:
- master
paths-ignore:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
workflow_dispatch:
jobs:
test:
@@ -21,17 +28,16 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
mkdir data
mkdir data/plugins
mkdir data/config
mkdir temp
- name: Run tests
run: |
export LLM_MODEL=${{ secrets.LLM_MODEL }}
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
+1
View File
@@ -14,6 +14,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
</a>
<a href="https://astrbot.soulter.top/">查看文档</a>
+4
View File
@@ -8,5 +8,9 @@ os.makedirs("data", exist_ok=True)
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
if os.environ.get('TESTING', ""):
logger.setLevel('DEBUG')
db_helper = SQLiteDatabase(DB_PATH)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+8 -2
View File
@@ -17,6 +17,7 @@ DEFAULT_CONFIG = {
},
"reply_prefix": "",
"forward_threshold": 200,
"enable_id_white_list": True,
"id_whitelist": [],
"id_whitelist_log": True,
"wl_ignore_admin_on_group": True,
@@ -49,7 +50,8 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": ""
"plugin_repo_mirror": "",
"knowledge_db": {},
}
@@ -162,6 +164,10 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。",
},
"enable_id_white_list": {
"description": "启用 ID 白名单",
"type": "bool"
},
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
@@ -273,7 +279,7 @@ CONFIG_METADATA_2 = {
},
"zhipu": {
"id": "zhipu_default",
"type": "openai_chat_completion",
"type": "zhipu_chat_completion",
"enable": True,
"key": [],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
+15 -4
View File
@@ -16,6 +16,7 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
@@ -29,7 +30,10 @@ class AstrBotCoreLifecycle:
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
logger.setLevel(self.astrbot_config['log_level'])
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
else:
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
self.event_queue.closed = False
@@ -37,9 +41,16 @@ class AstrBotCoreLifecycle:
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
self.star_context.platform_manager = self.platform_manager
self.star_context.provider_manager = self.provider_manager
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
@@ -22,6 +22,9 @@ class LLMRequestSubStage(Stage):
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider is None:
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
@@ -27,7 +27,7 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_str not in star_map:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
@@ -39,7 +39,7 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
-1
View File
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
try:
ready_to_call = handler(event, **params)
except TypeError as e:
print(e)
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
@@ -10,12 +10,16 @@ class WhitelistCheckStage(Stage):
'''检查是否在群聊/私聊白名单
'''
async def initialize(self, ctx: PipelineContext) -> None:
self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
return
# 检查是否在白名单
if self.wl_ignore_admin_on_group:
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -11,7 +11,7 @@ from botpy import Client
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .qqofficial_message_event import QQOfficialMessageEvent
from ...register import register_platform_adapter
@@ -111,6 +111,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.message_id = message.id
abm.tag = "qq_official"
msg: List[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
if isinstance(message, botpy.message.GroupMessage):
@@ -126,7 +127,7 @@ class QQOfficialPlatformAdapter(Platform):
)
abm.message_str = message.content.strip()
abm.self_id = "unknown_selfid"
msg.append(At(qq="qq_official"))
msg.append(Plain(abm.message_str))
if message.attachments:
for i in message.attachments:
@@ -146,7 +147,7 @@ class QQOfficialPlatformAdapter(Platform):
plain_content = message.content.replace(
"<@!"+str(abm.self_id)+">", "").strip()
msg.append(Plain(plain_content))
if message.attachments:
for i in message.attachments:
if i.content_type.startswith("image"):
@@ -161,11 +162,14 @@ class QQOfficialPlatformAdapter(Platform):
str(message.author.id),
str(message.author.username)
)
msg.append(At(qq="qq_official"))
msg.append(Plain(plain_content))
if isinstance(message, botpy.message.Message):
abm.group_id = message.channel_id
else:
raise ValueError(f"Unknown message type: {message_type}")
abm.self_id = "qq_official"
return abm
def run(self):
+9
View File
@@ -18,6 +18,11 @@ class ProviderManager():
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
@@ -29,6 +34,8 @@ class ProviderManager():
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
@@ -54,6 +61,8 @@ class ProviderManager():
if len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
else:
logger.warning("未启用任何大模型提供商适配器。")
def get_insts(self):
return self.provider_insts
@@ -162,7 +162,12 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
@@ -180,8 +185,6 @@ class ProviderOpenAIOfficial(Provider):
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
return llm_response
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
@@ -0,0 +1,73 @@
import traceback
from astrbot.core.db import BaseDatabase
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
class ProviderZhipu(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
# glm-4v-flash 只支持一张图片
model: str = model_cfgs.get("model", "")
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
new_context_query_ = []
for i in range(0, len(context_query) - 1, 2):
if isinstance(context_query[i].get("content", ""), list):
continue
new_context_query_.append(context_query[i])
new_context_query_.append(context_query[i+1])
new_context_query_.append(context_query[-1]) # 保留最后一条记录
context_query = new_context_query_
logger.debug(context_query)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**model_cfgs
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
@@ -0,0 +1,25 @@
from typing import List
from openai import AsyncOpenAI
class SimpleOpenAIEmbedding():
def __init__(
self,
model,
api_key,
api_base=None,
) -> None:
self.client = AsyncOpenAI(
api_key=api_key,
base_url=api_base
)
self.model = model
async def get_embedding(self, text) -> List[float]:
'''
获取文本的嵌入
'''
embedding = await self.client.embeddings.create(
input=text,
model=self.model
)
return embedding.data[0].embedding
+92
View File
@@ -0,0 +1,92 @@
import os
from typing import List, Dict
from astrbot.core import logger
from .store import Store
from astrbot.core.config import AstrBotConfig
class KnowledgeDBManager():
def __init__(self, astrbot_config: AstrBotConfig) -> None:
self.db_path = "data/knowledge_db/"
self.config = astrbot_config.get("knowledge_db", {})
self.astrbot_config = astrbot_config
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
self.store_insts: Dict[str, Store] = {}
for name, cfg in self.config.items():
if cfg["strategy"] == "embedding":
logger.info(f"加载 Chroma Vector Store{name}")
try:
from .store.chroma_db import ChromaVectorStore
except ImportError as ie:
logger.error(f"{ie} 可能未安装 chromadb 库。")
continue
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
else:
logger.error(f"不支持的策略:{cfg['strategy']}")
async def list_knowledge_db(self) -> List[str]:
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
async def create_knowledge_db(self, name: str, config: Dict):
'''
config 格式:
```
{
"strategy": "embedding", # 目前只支持 embedding
"chunk_method": {
"strategy": "fixed",
"chunk_size": 100,
"overlap_size": 10
},
"embedding_config": {
"strategy": "openai",
"base_url": "",
"model": "",
"api_key": ""
}
}
```
'''
if name in self.config:
raise ValueError(f"知识库已存在:{name}")
self.config[name] = config
self.astrbot_config["knowledge_db"] = self.config
self.astrbot_config.save_config()
async def insert_record(self, name: str, text: str):
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
ret = []
match self.config[name]["chunk_method"]['strategy']:
case "fixed":
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
case _:
pass
for chunk in ret:
await self.store_insts[name].save(chunk)
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
inst = self.store_insts[name]
return await inst.query(query, top_n)
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - chunk_overlap
return chunks
+8
View File
@@ -0,0 +1,8 @@
from typing import List
class Store():
async def save(self, text: str):
pass
async def query(self, query: str, top_n: int = 3) -> List[str]:
pass
+39
View File
@@ -0,0 +1,39 @@
import chromadb
import uuid
from typing import List, Dict
from astrbot.api import logger
from ..embedding.openai_source import SimpleOpenAIEmbedding
from . import Store
class ChromaVectorStore(Store):
def __init__(self, name: str, embedding_cfg: Dict) -> None:
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
self.collection = self.chroma_client.get_or_create_collection(name=name)
self.embedding = None
if embedding_cfg["strategy"] == "openai":
self.embedding = SimpleOpenAIEmbedding(
model=embedding_cfg["model"],
api_key=embedding_cfg["api_key"],
api_base=embedding_cfg.get("base_url", None)
)
async def save(self, text: str, metadata: Dict = None):
logger.debug(f"Saving text: {text}")
embedding = await self.embedding.get_embedding(text)
self.collection.upsert(
documents=text,
metadatas=metadata,
ids=str(uuid.uuid4()),
embeddings=embedding
)
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
embedding = await self.embedding.get_embedding(query)
results = self.collection.query(
query_embeddings=embedding,
n_results=top_n,
where=metadata_filter
)
return results['documents'][0]
+14 -3
View File
@@ -14,6 +14,7 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class StarCommand(TypedDict):
full_command_name: str
@@ -39,10 +40,20 @@ class Context:
# back compatibility
_register_tasks: List[Awaitable] = []
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
def __init__(self,
event_queue: Queue,
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
self._config = config
self._db = db
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
for star in star_registry:
@@ -73,7 +84,7 @@ class Context:
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_str=func_obj.__module__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
@@ -125,7 +136,7 @@ class Context:
event_type=EventType.AdapterMessageEvent,
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_str=awaitable.__module__,
handler_module_path=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
+3
View File
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
return False
# if len(self.handler_params) == 0 and len(ls) > 1:
# # 一定程度避免 LLM 聊天时误判为指令
# return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 去除空字符串
+1 -1
View File
@@ -28,7 +28,7 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_str=handler.__module__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[]
)
+5 -6
View File
@@ -1,11 +1,11 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Awaitable, List, Dict
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
class StarHandlerRegistry(List):
T = TypeVar('T', bound='StarHandlerMetadata')
class StarHandlerRegistry(Generic[T], List[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -26,8 +26,7 @@ class StarHandlerRegistry(List):
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for handler in self if handler.handler_module_str == module_name]
return [handler for handler in self if handler.handler_module_path == module_name]
star_handlers_registry = StarHandlerRegistry()
@@ -55,7 +54,7 @@ class StarHandlerMetadata():
handler_name: str
'''Handler 的名字,也就是方法名'''
handler_module_str: str
handler_module_path: str
'''Handler 所在的模块路径。'''
handler: Awaitable
+31 -4
View File
@@ -1,6 +1,7 @@
import inspect
import functools
import os
import sys
import traceback
import yaml
import logging
@@ -14,9 +15,8 @@ from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from astrbot.core.provider.register import llm_tools
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
class PluginManager:
def __init__(
@@ -138,7 +138,18 @@ class PluginManager:
def reload(self):
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_handlers_registry.star_handlers_map.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
@@ -225,10 +236,11 @@ class PluginManager:
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
self._check_plugin_dept_update()
# reload the plugin
self.reload()
return plugin_path
def uninstall_plugin(self, plugin_name: str):
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -237,7 +249,20 @@ class PluginManager:
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
# 从 star_registry 和 star_map 中删除
del star_map[plugin.module_path]
for i, p in enumerate(star_registry):
if p.name == plugin_name:
del star_registry[i]
break
for handler in star_handlers_registry.get_handlers_by_module_name(plugin.module_path):
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
star_handlers_registry.remove(handler)
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin.module_path)]
for k in keys_to_delete:
v = star_handlers_registry.star_handlers_map[k]
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
del star_handlers_registry.star_handlers_map[k]
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
@@ -250,6 +275,7 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
self.reload()
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
@@ -262,3 +288,4 @@ class PluginManager:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
+1 -2
View File
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
logger.info(f"删除临时文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:
+1 -1
View File
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,从默认分支下载。")
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
+3 -6
View File
@@ -53,7 +53,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -69,7 +68,6 @@ class PluginRoute(Route):
await file.save(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
self.core_lifecycle.restart()
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
@@ -80,7 +78,7 @@ class PluginRoute(Route):
plugin_name = post_data["name"]
try:
logger.info(f"正在卸载插件 {plugin_name}")
self.plugin_manager.uninstall_plugin(plugin_name)
await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:
@@ -93,9 +91,8 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
self.core_lifecycle.restart()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+6 -2
View File
@@ -32,6 +32,7 @@ class UpdateRoute(Route):
async def update_project(self):
data = await request.json
version = data.get('version', '')
reboot = data.get('reboot', True)
if version == "" or version == "latest":
latest = True
version = ''
@@ -39,8 +40,11 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
if reboot:
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
else:
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
+6 -5
View File
@@ -1,4 +1,3 @@
import os
import asyncio
import sys
@@ -42,14 +41,16 @@ async def check_dashboard_files():
return
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
logger.info("开始下载管理面板文件...")
ok = False
async with aiohttp.ClientSession() as session:
async with session.get(dashboard_release_url) as resp:
if resp.status != 200:
logger.error(f"下载管理面板文件失败: {resp.status}")
with open("data/dashboard.zip", "wb") as f:
f.write(await resp.read())
logger.info("管理面板文件下载完成。")
ok = True
else:
with open("data/dashboard.zip", "wb") as f:
f.write(await resp.read())
logger.info("管理面板文件下载完成。")
ok = True
if not ok:
logger.critical("下载管理面板文件失败")
+33 -2
View File
@@ -16,6 +16,8 @@ class Main(star.Star):
self.prompt_prefix = cfg['provider_settings']['prompt_prefix']
self.identifier = cfg['provider_settings']['identifier']
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
self.kdb_enabled = False
async def _query_astrbot_notice(self):
try:
@@ -289,7 +291,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
- 重置 LLM 会话(保留人格): /reset p
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
"""))
""").use_t2i(False))
elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
@@ -337,4 +339,33 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
async def other_message(self, event: AstrMessageEvent):
print("triggered")
event.stop_event()
event.stop_event()
@filter.command_group("kdb")
def kdb(self):
pass
@kdb.command("on")
async def on_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = True
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if not curr_kdb_name:
yield event.plain_result("未载入任何知识库")
else:
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
@kdb.command("off")
async def off_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = False
yield event.plain_result("知识库已关闭")
@filter.on_llm_request()
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if self.kdb_enabled and curr_kdb_name:
mgr = self.context.knowledge_db_manager
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
if results:
req.system_prompt += "\nHere are documents that related to user's query: \n"
for result in results:
req.system_prompt += f"- {result}\n"
+148
View File
@@ -0,0 +1,148 @@
import pytest
import os
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core import LogBroker
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@pytest.fixture(scope="module")
def core_lifecycle_td():
db = SQLiteDatabase("data/data_v3.db")
log_broker = LogBroker()
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle_td
@pytest.fixture(scope="module")
def app(core_lifecycle_td):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle_td, db)
return server.app
@pytest.fixture(scope="module")
def header():
return {}
@pytest.mark.asyncio
async def test_init_core_lifecycle_td(core_lifecycle_td):
await core_lifecycle_td.initialize()
assert core_lifecycle_td is not None
@pytest.mark.asyncio
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict):
test_client = app.test_client()
response = await test_client.post('/api/auth/login', json={
"username": "wrong",
"password": "password"
})
data = await response.get_json()
assert data['status'] == 'error'
response = await test_client.post('/api/auth/login', json={
"username": core_lifecycle_td.astrbot_config['dashboard']['username'],
"password": core_lifecycle_td.astrbot_config['dashboard']['password']
})
data = await response.get_json()
assert data['status'] == 'ok' and 'token' in data['data']
header['Authorization'] = f"Bearer {data['data']['token']}"
@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/stat/get')
assert response.status_code == 401
response = await test_client.get('/api/stat/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok' and 'platform' in data['data']
@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get('/api/plugin/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件市场
response = await test_client.get('/api/plugin/market_list', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件安装
response = await test_client.post('/api/plugin/install', json={
"url": "https://github.com/Soulter/astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# 插件更新
response = await test_client.post('/api/plugin/update', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件卸载
response = await test_client.post('/api/plugin/uninstall', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/update/check', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'success'
@pytest.mark.asyncio
async def test_do_update(app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post('/api/update/do', headers=header, json={
"version": "latest"
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'error' # 已经是最新版本
response = await test_client.post('/api/update/do', headers=header, json={
"version": "v3.4.0",
"reboot": False
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
assert os.path.exists("data/astrbot_release/astrbot")
+48
View File
@@ -0,0 +1,48 @@
import os
import sys
import pytest
from unittest import mock
from main import check_env, check_dashboard_files
class _version_info():
def __init__(self, major, minor):
self.major = major
self.minor = minor
def test_check_env(monkeypatch):
version_info_correct = _version_info(3, 10)
version_info_wrong = _version_info(3, 9)
monkeypatch.setattr(sys, 'version_info', version_info_correct)
with mock.patch('os.makedirs') as mock_makedirs:
check_env()
mock_makedirs.assert_any_call("data/config", exist_ok=True)
mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
mock_makedirs.assert_any_call("data/temp", exist_ok=True)
monkeypatch.setattr(sys, 'version_info', version_info_wrong)
with pytest.raises(SystemExit):
check_env()
@pytest.mark.asyncio
async def test_check_dashboard_files(monkeypatch):
monkeypatch.setattr(os.path, 'exists', lambda x: False)
async def mock_get(*args, **kwargs):
class MockResponse:
status = 200
async def read(self):
return b'content'
return MockResponse()
with mock.patch('aiohttp.ClientSession.get', new=mock_get):
with mock.patch('builtins.open', mock.mock_open()) as mock_file:
with mock.patch('zipfile.ZipFile.extractall') as mock_extractall:
async def mock_aenter(_):
await check_dashboard_files()
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
mock_extractall.assert_called_once()
async def mock_aexit(obj, exc_type, exc, tb):
return
mock_extractall.__aenter__ = mock_aenter
mock_extractall.__aexit__ = mock_aexit
+224
View File
@@ -0,0 +1,224 @@
import pytest
import logging
import os
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.message.components import Plain, At
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from asyncio import Queue
SESSION_ID_IN_WHITELIST = "test_sid_wl"
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
TEST_LLM_PROVIDER = {
"id": "zhipu_default",
"type": "openai_chat_completion",
"enable": True,
"key": [os.getenv("ZHIPU_API_KEY")],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
}
TEST_COMMANDS = [
["help", "已注册的 AstrBot 内置指令"],
["tool ls", "函数工具"],
["tool on websearch", "激活工具"],
["tool off websearch", "停用工具"],
["plugin", "已加载的插件"],
["t2i", "文本转图片模式"],
["sid", "此 ID 可用于设置会话白名单。"],
["op test_op", "授权成功。"],
["deop test_op", "取消授权成功。"],
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
["provider", "当前载入的 LLM 提供商"],
["reset", "重置成功"],
# ["model", "查看、切换提供商模型列表"],
["history", "历史记录:"],
["key", "当前 Key"],
["persona", "[Persona]"]
]
class FakeAstrMessageEvent(AstrMessageEvent):
def __init__(self, abm: AstrBotMessage = None):
meta = PlatformMetadata("test_platform", "test")
super().__init__(
message_str=abm.message_str,
message_obj=abm,
platform_meta=meta,
session_id=abm.session_id
)
async def send(self, message: MessageChain):
await super().send(message)
@staticmethod
def create_fake_event(
message_str: str,
session_id: str = "test_sid",
is_at: bool = False,
is_group: bool = False,
sender_id: str = "123456"
):
abm = AstrBotMessage()
abm.message_str = message_str
abm.group_id = "test"
abm.message = [Plain(message_str)]
if is_at:
abm.message.append(At(qq="bot"))
abm.self_id = "bot"
abm.sender = MessageMember(sender_id, "mika")
abm.timestamp = 1234567890
abm.message_id = "test"
abm.session_id = session_id
if is_group:
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
return FakeAstrMessageEvent(abm)
@pytest.fixture(scope="module")
def event_queue():
return Queue()
@pytest.fixture(scope="module")
def config():
cfg = AstrBotConfig()
cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"]
cfg['admins_id'] = ["123456"]
cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"]
cfg['provider'] = [TEST_LLM_PROVIDER]
return cfg
@pytest.fixture(scope="module")
def db():
return SQLiteDatabase("data/data_v3.db")
@pytest.fixture(scope="module")
def platform_manager(event_queue, config):
return PlatformManager(config, event_queue)
@pytest.fixture(scope="module")
def provider_manager(config, db):
return ProviderManager(config, db)
@pytest.fixture(scope="module")
def star_context(event_queue, config, db, platform_manager, provider_manager):
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
return star_context
@pytest.fixture(scope="module")
def plugin_manager(star_context, config):
plugin_manager = PluginManager(star_context, config)
plugin_manager.reload()
return plugin_manager
@pytest.fixture(scope="module")
def pipeline_context(config, plugin_manager):
return PipelineContext(config, plugin_manager)
@pytest.fixture(scope="module")
def pipeline_scheduler(pipeline_context):
return PipelineScheduler(pipeline_context)
@pytest.mark.asyncio
async def test_platform_initialization(platform_manager: PlatformManager):
await platform_manager.initialize()
@pytest.mark.asyncio
async def test_provider_initialization(provider_manager: ProviderManager):
await provider_manager.initialize()
@pytest.mark.asyncio
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
await pipeline_scheduler.initialize()
@pytest.mark.asyncio
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
'''测试唤醒'''
# 群聊无 @ 无指令
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages)
# 群聊有 @ 无指令
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
# 群聊有指令
mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST)
await pipeline_scheduler.execute(mock_event)
assert mock_event._has_send_oper is True
@pytest.mark.asyncio
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
@pytest.mark.asyncio
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
# 测试默认屏蔽词
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
# 测试额外屏蔽词
mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" not in message for message in caplog.messages)
# TODO: 测试 百度AI 的内容安全检查
@pytest.mark.asyncio
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert mock_event.get_result() is not None
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
@pytest.mark.asyncio
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages)
@pytest.mark.asyncio
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
for command in TEST_COMMANDS:
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
assert any(command[1] in message for message in caplog.messages)
+80
View File
@@ -0,0 +1,80 @@
import pytest
import os
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
from astrbot.core.star.context import Context
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase
from asyncio import Queue
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase("data/data_v3.db")
star_context = Context(event_queue, config, db)
@pytest.fixture
def plugin_manager_pm():
return PluginManager(star_context, config)
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
assert plugin_manager_pm is not None
assert plugin_manager_pm.context is not None
assert plugin_manager_pm.config is not None
@pytest.mark.asyncio
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = plugin_manager_pm.reload()
assert success is True
assert err_message is None
assert len(star_handlers_registry) > 0 # package
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):
'''测试插件安装和重载'''
os.makedirs("data/plugins", exist_ok=True)
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert plugin_path is not None
assert os.path.exists(plugin_path)
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# shutil.rmtree(plugin_path)
# install plugin which is not exists
with pytest.raises(Exception):
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
# update
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
# uninstall
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
assert not os.path.exists(plugin_path)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
with pytest.raises(Exception):
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
# TODO: file installation