diff --git a/.github/workflows/coverage_test.yml b/.github/workflows/coverage_test.yml
index 175e5f564..30e9237ed 100644
--- a/.github/workflows/coverage_test.yml
+++ b/.github/workflows/coverage_test.yml
@@ -1,7 +1,14 @@
name: Run tests and upload coverage
on:
- push
+ push:
+ branches:
+ - master
+ paths-ignore:
+ - 'README.md'
+ - 'changelogs/**'
+ - 'dashboard/**'
+ workflow_dispatch:
jobs:
test:
@@ -21,17 +28,16 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
- mkdir data
- mkdir data/plugins
- mkdir data/config
- mkdir temp
- name: Run tests
run: |
- export LLM_MODEL=${{ secrets.LLM_MODEL }}
- export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
- export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
- PYTHONPATH=./ pytest --cov=. tests/ -v
+ mkdir data
+ mkdir data/plugins
+ mkdir data/config
+ mkdir data/temp
+ export TESTING=true
+ export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
+ PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
diff --git a/README.md b/README.md
index f35d93404..5a80add94 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
+[](https://codecov.io/gh/Soulter/AstrBot)
查看文档 |
diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py
index 5efeff30b..94e32224c 100644
--- a/astrbot/core/__init__.py
+++ b/astrbot/core/__init__.py
@@ -8,5 +8,9 @@ os.makedirs("data", exist_ok=True)
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
+
+if os.environ.get('TESTING', ""):
+ logger.setLevel('DEBUG')
+
db_helper = SQLiteDatabase(DB_PATH)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
\ No newline at end of file
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 7605473c6..9c78c0220 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -17,6 +17,7 @@ DEFAULT_CONFIG = {
},
"reply_prefix": "",
"forward_threshold": 200,
+ "enable_id_white_list": True,
"id_whitelist": [],
"id_whitelist_log": True,
"wl_ignore_admin_on_group": True,
@@ -49,7 +50,8 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
- "plugin_repo_mirror": ""
+ "plugin_repo_mirror": "",
+ "knowledge_db": {},
}
@@ -162,6 +164,10 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。",
},
+ "enable_id_white_list": {
+ "description": "启用 ID 白名单",
+ "type": "bool"
+ },
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
@@ -273,7 +279,7 @@ CONFIG_METADATA_2 = {
},
"zhipu": {
"id": "zhipu_default",
- "type": "openai_chat_completion",
+ "type": "zhipu_chat_completion",
"enable": True,
"key": [],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py
index cc5d57a9e..3655e55ee 100644
--- a/astrbot/core/core_lifecycle.py
+++ b/astrbot/core/core_lifecycle.py
@@ -16,6 +16,7 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
+from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
@@ -29,7 +30,10 @@ class AstrBotCoreLifecycle:
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
- logger.setLevel(self.astrbot_config['log_level'])
+ if os.environ.get("TESTING", ""):
+ logger.setLevel("DEBUG")
+ else:
+ logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
self.event_queue.closed = False
@@ -37,9 +41,16 @@ class AstrBotCoreLifecycle:
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
- self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
- self.star_context.platform_manager = self.platform_manager
- self.star_context.provider_manager = self.provider_manager
+ self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
+
+ self.star_context = Context(
+ self.event_queue,
+ self.astrbot_config,
+ self.db,
+ self.provider_manager,
+ self.platform_manager,
+ self.knowledge_db_manager
+ )
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index 6e1ce6b42..6e2a25960 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -22,6 +22,9 @@ class LLMRequestSubStage(Stage):
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
+ if provider is None:
+ return
+
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py
index 5e7e4e6aa..6863df473 100644
--- a/astrbot/core/pipeline/process_stage/method/star_request.py
+++ b/astrbot/core/pipeline/process_stage/method/star_request.py
@@ -27,7 +27,7 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
- if handler.handler_module_str not in star_map:
+ if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
@@ -39,7 +39,7 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
- ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
+ ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py
index 77a7dbeea..a5851adf5 100644
--- a/astrbot/core/pipeline/stage.py
+++ b/astrbot/core/pipeline/stage.py
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
try:
ready_to_call = handler(event, **params)
except TypeError as e:
- print(e)
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py
index b2b713a6a..6a4e7097e 100644
--- a/astrbot/core/pipeline/whitelist_check/stage.py
+++ b/astrbot/core/pipeline/whitelist_check/stage.py
@@ -10,12 +10,16 @@ class WhitelistCheckStage(Stage):
'''检查是否在群聊/私聊白名单
'''
async def initialize(self, ctx: PipelineContext) -> None:
+ self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
+ if not self.enable_whitelist_check:
+ return
+
# 检查是否在白名单
if self.wl_ignore_admin_on_group:
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
index 0f73ed5fb..de94ab6a4 100644
--- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
+++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py
@@ -11,7 +11,7 @@ from botpy import Client
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
-from astrbot.api.message_components import Image, Plain
+from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .qqofficial_message_event import QQOfficialMessageEvent
from ...register import register_platform_adapter
@@ -111,6 +111,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.message_id = message.id
abm.tag = "qq_official"
msg: List[BaseMessageComponent] = []
+
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
if isinstance(message, botpy.message.GroupMessage):
@@ -126,7 +127,7 @@ class QQOfficialPlatformAdapter(Platform):
)
abm.message_str = message.content.strip()
abm.self_id = "unknown_selfid"
-
+ msg.append(At(qq="qq_official"))
msg.append(Plain(abm.message_str))
if message.attachments:
for i in message.attachments:
@@ -146,7 +147,7 @@ class QQOfficialPlatformAdapter(Platform):
plain_content = message.content.replace(
"<@!"+str(abm.self_id)+">", "").strip()
- msg.append(Plain(plain_content))
+
if message.attachments:
for i in message.attachments:
if i.content_type.startswith("image"):
@@ -161,11 +162,14 @@ class QQOfficialPlatformAdapter(Platform):
str(message.author.id),
str(message.author.username)
)
+ msg.append(At(qq="qq_official"))
+ msg.append(Plain(plain_content))
if isinstance(message, botpy.message.Message):
abm.group_id = message.channel_id
else:
raise ValueError(f"Unknown message type: {message_type}")
+ abm.self_id = "qq_official"
return abm
def run(self):
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 200cc1bca..20ac8a0d0 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -18,6 +18,11 @@ class ProviderManager():
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
+ self.curr_kdb_name = ""
+ kdb_cfg = config.get("knowledge_db", {})
+ if kdb_cfg and len(kdb_cfg):
+ self.curr_kdb_name = list(kdb_cfg.keys())[0]
+
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
@@ -29,6 +34,8 @@ class ProviderManager():
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
+ case "zhipu_chat_completion":
+ from .sources.zhipu_source import ProviderZhipu # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
@@ -54,6 +61,8 @@ class ProviderManager():
if len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
+ else:
+ logger.warning("未启用任何大模型提供商适配器。")
def get_insts(self):
return self.provider_insts
\ No newline at end of file
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 1cc792bf3..b0c9f0a58 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -162,7 +162,12 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
-
+
+ await self.save_history(contexts, new_record, session_id, llm_response)
+
+ return llm_response
+
+ async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
@@ -180,8 +185,6 @@ class ProviderOpenAIOfficial(Provider):
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
- return llm_response
-
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py
new file mode 100644
index 000000000..3b0434518
--- /dev/null
+++ b/astrbot/core/provider/sources/zhipu_source.py
@@ -0,0 +1,73 @@
+import traceback
+from astrbot.core.db import BaseDatabase
+from astrbot import logger
+from astrbot.core.provider.func_tool_manager import FuncCall
+from typing import List
+from ..register import register_provider_adapter
+from astrbot.core.provider.entites import LLMResponse
+from .openai_source import ProviderOpenAIOfficial
+
+@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
+class ProviderZhipu(ProviderOpenAIOfficial):
+ def __init__(
+ self,
+ provider_config: dict,
+ provider_settings: dict,
+ db_helper: BaseDatabase,
+ persistant_history = True
+ ) -> None:
+ super().__init__(provider_config, provider_settings, db_helper, persistant_history)
+
+ async def text_chat(
+ self,
+ prompt: str,
+ session_id: str,
+ image_urls: List[str]=None,
+ func_tool: FuncCall=None,
+ contexts=None,
+ system_prompt=None,
+ **kwargs
+ ) -> LLMResponse:
+ new_record = await self.assemble_context(prompt, image_urls)
+ context_query = []
+
+ if not contexts:
+ context_query = [*self.session_memory[session_id], new_record]
+ else:
+ context_query = [*contexts, new_record]
+
+ model_cfgs: dict = self.provider_config.get("model_config", {})
+ # glm-4v-flash 只支持一张图片
+ model: str = model_cfgs.get("model", "")
+ if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
+ logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
+ logger.debug(context_query)
+ new_context_query_ = []
+ for i in range(0, len(context_query) - 1, 2):
+ if isinstance(context_query[i].get("content", ""), list):
+ continue
+ new_context_query_.append(context_query[i])
+ new_context_query_.append(context_query[i+1])
+ new_context_query_.append(context_query[-1]) # 保留最后一条记录
+ context_query = new_context_query_
+ logger.debug(context_query)
+
+ if system_prompt:
+ context_query.insert(0, {"role": "system", "content": system_prompt})
+
+ payloads = {
+ "messages": context_query,
+ **model_cfgs
+ }
+ llm_response = None
+ try:
+ llm_response = await self._query(payloads, func_tool)
+ except Exception as e:
+ if "maximum context length" in str(e):
+ logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
+ self.pop_record(session_id)
+ logger.warning(traceback.format_exc())
+
+ await self.save_history(contexts, new_record, session_id, llm_response)
+
+ return llm_response
\ No newline at end of file
diff --git a/astrbot/core/rag/embedding/openai_source.py b/astrbot/core/rag/embedding/openai_source.py
new file mode 100644
index 000000000..648de0fda
--- /dev/null
+++ b/astrbot/core/rag/embedding/openai_source.py
@@ -0,0 +1,25 @@
+from typing import List
+from openai import AsyncOpenAI
+
+class SimpleOpenAIEmbedding():
+ def __init__(
+ self,
+ model,
+ api_key,
+ api_base=None,
+ ) -> None:
+ self.client = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base
+ )
+ self.model = model
+
+ async def get_embedding(self, text) -> List[float]:
+ '''
+ 获取文本的嵌入
+ '''
+ embedding = await self.client.embeddings.create(
+ input=text,
+ model=self.model
+ )
+ return embedding.data[0].embedding
diff --git a/astrbot/core/rag/knowledge_db_mgr.py b/astrbot/core/rag/knowledge_db_mgr.py
new file mode 100644
index 000000000..2ee8199b7
--- /dev/null
+++ b/astrbot/core/rag/knowledge_db_mgr.py
@@ -0,0 +1,92 @@
+import os
+from typing import List, Dict
+from astrbot.core import logger
+from .store import Store
+from astrbot.core.config import AstrBotConfig
+
+class KnowledgeDBManager():
+ def __init__(self, astrbot_config: AstrBotConfig) -> None:
+ self.db_path = "data/knowledge_db/"
+ self.config = astrbot_config.get("knowledge_db", {})
+ self.astrbot_config = astrbot_config
+ if not os.path.exists(self.db_path):
+ os.makedirs(self.db_path)
+ self.store_insts: Dict[str, Store] = {}
+ for name, cfg in self.config.items():
+ if cfg["strategy"] == "embedding":
+ logger.info(f"加载 Chroma Vector Store:{name}")
+ try:
+ from .store.chroma_db import ChromaVectorStore
+ except ImportError as ie:
+ logger.error(f"{ie} 可能未安装 chromadb 库。")
+ continue
+ self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
+ else:
+ logger.error(f"不支持的策略:{cfg['strategy']}")
+
+
+ async def list_knowledge_db(self) -> List[str]:
+ return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
+
+
+ async def create_knowledge_db(self, name: str, config: Dict):
+ '''
+ config 格式:
+ ```
+ {
+ "strategy": "embedding", # 目前只支持 embedding
+ "chunk_method": {
+ "strategy": "fixed",
+ "chunk_size": 100,
+ "overlap_size": 10
+ },
+ "embedding_config": {
+ "strategy": "openai",
+ "base_url": "",
+ "model": "",
+ "api_key": ""
+ }
+ }
+ ```
+ '''
+ if name in self.config:
+ raise ValueError(f"知识库已存在:{name}")
+
+ self.config[name] = config
+ self.astrbot_config["knowledge_db"] = self.config
+ self.astrbot_config.save_config()
+
+
+ async def insert_record(self, name: str, text: str):
+ if name not in self.store_insts:
+ raise ValueError(f"未找到知识库:{name}")
+
+ ret = []
+ match self.config[name]["chunk_method"]['strategy']:
+ case "fixed":
+ chunk_size = self.config[name]["chunk_method"]["chunk_size"]
+ chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
+ ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
+ case _:
+ pass
+
+ for chunk in ret:
+ await self.store_insts[name].save(chunk)
+
+
+ async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
+ if name not in self.store_insts:
+ raise ValueError(f"未找到知识库:{name}")
+
+ inst = self.store_insts[name]
+ return await inst.query(query, top_n)
+
+
+ def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
+ chunks = []
+ start = 0
+ while start < len(text):
+ end = start + chunk_size
+ chunks.append(text[start:end])
+ start += chunk_size - chunk_overlap
+ return chunks
\ No newline at end of file
diff --git a/astrbot/core/rag/store/__init__.py b/astrbot/core/rag/store/__init__.py
new file mode 100644
index 000000000..cd4a3060a
--- /dev/null
+++ b/astrbot/core/rag/store/__init__.py
@@ -0,0 +1,8 @@
+from typing import List
+
+class Store():
+ async def save(self, text: str):
+ pass
+
+ async def query(self, query: str, top_n: int = 3) -> List[str]:
+ pass
diff --git a/astrbot/core/rag/store/chroma_db.py b/astrbot/core/rag/store/chroma_db.py
new file mode 100644
index 000000000..58ee9d9fb
--- /dev/null
+++ b/astrbot/core/rag/store/chroma_db.py
@@ -0,0 +1,39 @@
+import chromadb
+import uuid
+from typing import List, Dict
+from astrbot.api import logger
+from ..embedding.openai_source import SimpleOpenAIEmbedding
+from . import Store
+
+class ChromaVectorStore(Store):
+ def __init__(self, name: str, embedding_cfg: Dict) -> None:
+ self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
+ self.collection = self.chroma_client.get_or_create_collection(name=name)
+ self.embedding = None
+ if embedding_cfg["strategy"] == "openai":
+ self.embedding = SimpleOpenAIEmbedding(
+ model=embedding_cfg["model"],
+ api_key=embedding_cfg["api_key"],
+ api_base=embedding_cfg.get("base_url", None)
+ )
+
+ async def save(self, text: str, metadata: Dict = None):
+ logger.debug(f"Saving text: {text}")
+ embedding = await self.embedding.get_embedding(text)
+
+ self.collection.upsert(
+ documents=text,
+ metadatas=metadata,
+ ids=str(uuid.uuid4()),
+ embeddings=embedding
+ )
+
+ async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
+ embedding = await self.embedding.get_embedding(query)
+
+ results = self.collection.query(
+ query_embeddings=embedding,
+ n_results=top_n,
+ where=metadata_filter
+ )
+ return results['documents'][0]
diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py
index 108c25250..39ed5baf6 100644
--- a/astrbot/core/star/context.py
+++ b/astrbot/core/star/context.py
@@ -14,6 +14,7 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
+from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class StarCommand(TypedDict):
full_command_name: str
@@ -39,10 +40,20 @@ class Context:
# back compatibility
_register_tasks: List[Awaitable] = []
- def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
+ def __init__(self,
+ event_queue: Queue,
+ config: AstrBotConfig,
+ db: BaseDatabase,
+ provider_manager: ProviderManager = None,
+ platform_manager: PlatformManager = None,
+ knowledge_db_manager: KnowledgeDBManager = None
+ ):
self._event_queue = event_queue
self._config = config
self._db = db
+ self.provider_manager = provider_manager
+ self.platform_manager = platform_manager
+ self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
for star in star_registry:
@@ -73,7 +84,7 @@ class Context:
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
- handler_module_str=func_obj.__module__,
+ handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
@@ -125,7 +136,7 @@ class Context:
event_type=EventType.AdapterMessageEvent,
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
- handler_module_str=awaitable.__module__,
+ handler_module_path=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py
index 72b8b160b..0976e59f0 100644
--- a/astrbot/core/star/filter/command.py
+++ b/astrbot/core/star/filter/command.py
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
return False
+ # if len(self.handler_params) == 0 and len(ls) > 1:
+ # # 一定程度避免 LLM 聊天时误判为指令
+ # return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 去除空字符串
diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py
index db0c46bb1..a46c05ae3 100644
--- a/astrbot/core/star/register/star_handler.py
+++ b/astrbot/core/star/register/star_handler.py
@@ -28,7 +28,7 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
- handler_module_str=handler.__module__,
+ handler_module_path=handler.__module__,
handler=handler,
event_filters=[]
)
diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py
index 58fe7b8b9..9acfa56e0 100644
--- a/astrbot/core/star/star_handler.py
+++ b/astrbot/core/star/star_handler.py
@@ -1,11 +1,11 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
-from typing import Awaitable, List, Dict
+from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
-
-class StarHandlerRegistry(List):
+T = TypeVar('T', bound='StarHandlerMetadata')
+class StarHandlerRegistry(Generic[T], List[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -26,8 +26,7 @@ class StarHandlerRegistry(List):
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
- return [handler for handler in self if handler.handler_module_str == module_name]
-
+ return [handler for handler in self if handler.handler_module_path == module_name]
star_handlers_registry = StarHandlerRegistry()
@@ -55,7 +54,7 @@ class StarHandlerMetadata():
handler_name: str
'''Handler 的名字,也就是方法名'''
- handler_module_str: str
+ handler_module_path: str
'''Handler 所在的模块路径。'''
handler: Awaitable
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 6f597b143..1a49f9826 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -1,6 +1,7 @@
import inspect
import functools
import os
+import sys
import traceback
import yaml
import logging
@@ -14,9 +15,8 @@ from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
-from astrbot.core.provider.register import llm_tools
-
from .star_handler import star_handlers_registry
+from astrbot.core.provider.register import llm_tools
class PluginManager:
def __init__(
@@ -138,7 +138,18 @@ class PluginManager:
def reload(self):
'''扫描并加载所有的 Star'''
+ for smd in star_registry:
+ logger.debug(f"尝试终止插件 {smd.name} ...")
+ if hasattr(smd.star_cls, "__del__"):
+ smd.star_cls.__del__()
+
star_handlers_registry.clear()
+ star_handlers_registry.star_handlers_map.clear()
+ star_map.clear()
+ star_registry.clear()
+ for key in list(sys.modules.keys()):
+ if key.startswith("data.plugins") or key.startswith("packages"):
+ del sys.modules[key]
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
@@ -225,10 +236,11 @@ class PluginManager:
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
- self._check_plugin_dept_update()
+ # reload the plugin
+ self.reload()
return plugin_path
- def uninstall_plugin(self, plugin_name: str):
+ async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -237,7 +249,20 @@ class PluginManager:
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
+ # 从 star_registry 和 star_map 中删除
del star_map[plugin.module_path]
+ for i, p in enumerate(star_registry):
+ if p.name == plugin_name:
+ del star_registry[i]
+ break
+ for handler in star_handlers_registry.get_handlers_by_module_name(plugin.module_path):
+ logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
+ star_handlers_registry.remove(handler)
+ keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin.module_path)]
+ for k in keys_to_delete:
+ v = star_handlers_registry.star_handlers_map[k]
+ logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
+ del star_handlers_registry.star_handlers_map[k]
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
@@ -250,6 +275,7 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
+ self.reload()
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
@@ -262,3 +288,4 @@ class PluginManager:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
+
diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py
index 93c7aefbd..02b9dc2da 100644
--- a/astrbot/core/star/updator.py
+++ b/astrbot/core/star/updator.py
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
- logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
- logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
+ logger.info(f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:
diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py
index ed3657531..ade94fa6e 100644
--- a/astrbot/core/zip_updator.py
+++ b/astrbot/core/zip_updator.py
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
- logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
+ logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py
index d22582c55..c43ab7660 100644
--- a/astrbot/dashboard/routes/plugin.py
+++ b/astrbot/dashboard/routes/plugin.py
@@ -53,7 +53,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
- self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -69,7 +68,6 @@ class PluginRoute(Route):
await file.save(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
- self.core_lifecycle.restart()
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
@@ -80,7 +78,7 @@ class PluginRoute(Route):
plugin_name = post_data["name"]
try:
logger.info(f"正在卸载插件 {plugin_name}")
- self.plugin_manager.uninstall_plugin(plugin_name)
+ await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:
@@ -93,9 +91,8 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
- self.core_lifecycle.restart()
- logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
- return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
+ logger.info(f"更新插件 {plugin_name} 成功。")
+ return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
\ No newline at end of file
diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py
index 8fc6fcd53..03f241de4 100644
--- a/astrbot/dashboard/routes/update.py
+++ b/astrbot/dashboard/routes/update.py
@@ -32,6 +32,7 @@ class UpdateRoute(Route):
async def update_project(self):
data = await request.json
version = data.get('version', '')
+ reboot = data.get('reboot', True)
if version == "" or version == "latest":
latest = True
version = ''
@@ -39,8 +40,11 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
- threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
- return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
+ if reboot:
+ threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
+ return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
+ else:
+ return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
\ No newline at end of file
diff --git a/main.py b/main.py
index e17ce3b84..fd0b1712f 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,3 @@
-
import os
import asyncio
import sys
@@ -42,14 +41,16 @@ async def check_dashboard_files():
return
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
logger.info("开始下载管理面板文件...")
+ ok = False
async with aiohttp.ClientSession() as session:
async with session.get(dashboard_release_url) as resp:
if resp.status != 200:
logger.error(f"下载管理面板文件失败: {resp.status}")
- with open("data/dashboard.zip", "wb") as f:
- f.write(await resp.read())
- logger.info("管理面板文件下载完成。")
- ok = True
+ else:
+ with open("data/dashboard.zip", "wb") as f:
+ f.write(await resp.read())
+ logger.info("管理面板文件下载完成。")
+ ok = True
if not ok:
logger.critical("下载管理面板文件失败")
diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py
index 6df2208cf..0af740faf 100644
--- a/packages/astrbot/main.py
+++ b/packages/astrbot/main.py
@@ -16,6 +16,8 @@ class Main(star.Star):
self.prompt_prefix = cfg['provider_settings']['prompt_prefix']
self.identifier = cfg['provider_settings']['identifier']
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
+
+ self.kdb_enabled = False
async def _query_astrbot_notice(self):
try:
@@ -289,7 +291,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo
- 重置 LLM 会话(保留人格): /reset p
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
-"""))
+""").use_t2i(False))
elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
@@ -337,4 +339,33 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
async def other_message(self, event: AstrMessageEvent):
print("triggered")
- event.stop_event()
\ No newline at end of file
+ event.stop_event()
+
+ @filter.command_group("kdb")
+ def kdb(self):
+ pass
+
+ @kdb.command("on")
+ async def on_kdb(self, event: AstrMessageEvent):
+ self.kdb_enabled = True
+ curr_kdb_name = self.context.provider_manager.curr_kdb_name
+ if not curr_kdb_name:
+ yield event.plain_result("未载入任何知识库")
+ else:
+ yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
+
+ @kdb.command("off")
+ async def off_kdb(self, event: AstrMessageEvent):
+ self.kdb_enabled = False
+ yield event.plain_result("知识库已关闭")
+
+ @filter.on_llm_request()
+ async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
+ curr_kdb_name = self.context.provider_manager.curr_kdb_name
+ if self.kdb_enabled and curr_kdb_name:
+ mgr = self.context.knowledge_db_manager
+ results = await mgr.retrive_records(curr_kdb_name, req.prompt)
+ if results:
+ req.system_prompt += "\nHere are documents that related to user's query: \n"
+ for result in results:
+ req.system_prompt += f"- {result}\n"
\ No newline at end of file
diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py
new file mode 100644
index 000000000..d4a38ad9b
--- /dev/null
+++ b/tests/test_dashboard.py
@@ -0,0 +1,148 @@
+import pytest
+import os
+from quart import Quart
+from astrbot.dashboard.server import AstrBotDashboard
+from astrbot.core.db.sqlite import SQLiteDatabase
+from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+from astrbot.core import LogBroker
+from astrbot.core.star.star_handler import star_handlers_registry
+from astrbot.core.star.star import star_registry
+
+
+@pytest.fixture(scope="module")
+def core_lifecycle_td():
+ db = SQLiteDatabase("data/data_v3.db")
+ log_broker = LogBroker()
+ core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
+ return core_lifecycle_td
+
+@pytest.fixture(scope="module")
+def app(core_lifecycle_td):
+ db = SQLiteDatabase("data/data_v3.db")
+ server = AstrBotDashboard(core_lifecycle_td, db)
+ return server.app
+
+@pytest.fixture(scope="module")
+def header():
+ return {}
+
+@pytest.mark.asyncio
+async def test_init_core_lifecycle_td(core_lifecycle_td):
+ await core_lifecycle_td.initialize()
+ assert core_lifecycle_td is not None
+
+@pytest.mark.asyncio
+async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict):
+ test_client = app.test_client()
+ response = await test_client.post('/api/auth/login', json={
+ "username": "wrong",
+ "password": "password"
+ })
+ data = await response.get_json()
+ assert data['status'] == 'error'
+
+ response = await test_client.post('/api/auth/login', json={
+ "username": core_lifecycle_td.astrbot_config['dashboard']['username'],
+ "password": core_lifecycle_td.astrbot_config['dashboard']['password']
+ })
+ data = await response.get_json()
+ assert data['status'] == 'ok' and 'token' in data['data']
+ header['Authorization'] = f"Bearer {data['data']['token']}"
+
+@pytest.mark.asyncio
+async def test_get_stat(app: Quart, header: dict):
+ test_client = app.test_client()
+ response = await test_client.get('/api/stat/get')
+ assert response.status_code == 401
+ response = await test_client.get('/api/stat/get', headers=header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'ok' and 'platform' in data['data']
+
+@pytest.mark.asyncio
+async def test_plugins(app: Quart, header: dict):
+ test_client = app.test_client()
+ # 已经安装的插件
+ response = await test_client.get('/api/plugin/get', headers=header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'ok'
+
+ # 插件市场
+ response = await test_client.get('/api/plugin/market_list', headers=header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'ok'
+
+ # 插件安装
+ response = await test_client.post('/api/plugin/install', json={
+ "url": "https://github.com/Soulter/astrbot_plugin_essential"
+ }, headers=header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'ok'
+ exists = False
+ for md in star_registry:
+ if md.name == "astrbot_plugin_essential":
+ exists = True
+ break
+ assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
+
+ # 插件更新
+ response = await test_client.post('/api/plugin/update', json={
+ "name": "astrbot_plugin_essential"
+ }, headers=header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'ok'
+
+ # 插件卸载
+ response = await test_client.post('/api/plugin/uninstall', json={
+ "name": "astrbot_plugin_essential"
+ }, headers=header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'ok'
+ exists = False
+ for md in star_registry:
+ if md.name == "astrbot_plugin_essential":
+ exists = True
+ break
+ assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
+ exists = False
+ for md in star_handlers_registry:
+ if "astrbot_plugin_essential" in md.handler_module_path:
+ exists = True
+ break
+ assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
+
+@pytest.mark.asyncio
+async def test_check_update(app: Quart, header: dict):
+ test_client = app.test_client()
+ response = await test_client.get('/api/update/check', headers=header)
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'success'
+
+@pytest.mark.asyncio
+async def test_do_update(app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle):
+ global VERSION
+ test_client = app.test_client()
+ os.makedirs("data/astrbot_release", exist_ok=True)
+ core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
+ VERSION = "114.514.1919810"
+ response = await test_client.post('/api/update/do', headers=header, json={
+ "version": "latest"
+ })
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'error' # 已经是最新版本
+
+ response = await test_client.post('/api/update/do', headers=header, json={
+ "version": "v3.4.0",
+ "reboot": False
+ })
+ assert response.status_code == 200
+ data = await response.get_json()
+ assert data['status'] == 'ok'
+ assert os.path.exists("data/astrbot_release/astrbot")
\ No newline at end of file
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 000000000..d2201e448
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,48 @@
+import os
+import sys
+import pytest
+from unittest import mock
+from main import check_env, check_dashboard_files
+
+class _version_info():
+ def __init__(self, major, minor):
+ self.major = major
+ self.minor = minor
+
+def test_check_env(monkeypatch):
+ version_info_correct = _version_info(3, 10)
+ version_info_wrong = _version_info(3, 9)
+ monkeypatch.setattr(sys, 'version_info', version_info_correct)
+ with mock.patch('os.makedirs') as mock_makedirs:
+ check_env()
+ mock_makedirs.assert_any_call("data/config", exist_ok=True)
+ mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
+ mock_makedirs.assert_any_call("data/temp", exist_ok=True)
+
+ monkeypatch.setattr(sys, 'version_info', version_info_wrong)
+ with pytest.raises(SystemExit):
+ check_env()
+
+@pytest.mark.asyncio
+async def test_check_dashboard_files(monkeypatch):
+ monkeypatch.setattr(os.path, 'exists', lambda x: False)
+ async def mock_get(*args, **kwargs):
+ class MockResponse:
+ status = 200
+ async def read(self):
+ return b'content'
+ return MockResponse()
+
+ with mock.patch('aiohttp.ClientSession.get', new=mock_get):
+ with mock.patch('builtins.open', mock.mock_open()) as mock_file:
+ with mock.patch('zipfile.ZipFile.extractall') as mock_extractall:
+ async def mock_aenter(_):
+ await check_dashboard_files()
+ mock_file.assert_called_once_with("data/dashboard.zip", "wb")
+ mock_extractall.assert_called_once()
+
+ async def mock_aexit(obj, exc_type, exc, tb):
+ return
+
+ mock_extractall.__aenter__ = mock_aenter
+ mock_extractall.__aexit__ = mock_aexit
\ No newline at end of file
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 000000000..4d90fae88
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,224 @@
+import pytest
+import logging
+import os
+from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
+from astrbot.core.star import PluginManager
+from astrbot.core.config.astrbot_config import AstrBotConfig
+from astrbot.core.platform.astr_message_event import AstrMessageEvent
+from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
+from astrbot.core.message.message_event_result import MessageChain, ResultContentType
+from astrbot.core.message.components import Plain, At
+from astrbot.core.platform.platform_metadata import PlatformMetadata
+from astrbot.core.platform.manager import PlatformManager
+from astrbot.core.provider.manager import ProviderManager
+from astrbot.core.db.sqlite import SQLiteDatabase
+from astrbot.core.star.context import Context
+from asyncio import Queue
+
+SESSION_ID_IN_WHITELIST = "test_sid_wl"
+SESSION_ID_NOT_IN_WHITELIST = "test_sid"
+TEST_LLM_PROVIDER = {
+ "id": "zhipu_default",
+ "type": "openai_chat_completion",
+ "enable": True,
+ "key": [os.getenv("ZHIPU_API_KEY")],
+ "api_base": "https://open.bigmodel.cn/api/paas/v4/",
+ "model_config": {
+ "model": "glm-4-flash",
+ },
+}
+
+TEST_COMMANDS = [
+ ["help", "已注册的 AstrBot 内置指令"],
+ ["tool ls", "函数工具"],
+ ["tool on websearch", "激活工具"],
+ ["tool off websearch", "停用工具"],
+ ["plugin", "已加载的插件"],
+ ["t2i", "文本转图片模式"],
+ ["sid", "此 ID 可用于设置会话白名单。"],
+ ["op test_op", "授权成功。"],
+ ["deop test_op", "取消授权成功。"],
+ ["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
+ ["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
+ ["provider", "当前载入的 LLM 提供商"],
+ ["reset", "重置成功"],
+ # ["model", "查看、切换提供商模型列表"],
+ ["history", "历史记录:"],
+ ["key", "当前 Key"],
+ ["persona", "[Persona]"]
+]
+
+class FakeAstrMessageEvent(AstrMessageEvent):
+ def __init__(self, abm: AstrBotMessage = None):
+ meta = PlatformMetadata("test_platform", "test")
+ super().__init__(
+ message_str=abm.message_str,
+ message_obj=abm,
+ platform_meta=meta,
+ session_id=abm.session_id
+ )
+
+ async def send(self, message: MessageChain):
+ await super().send(message)
+
+ @staticmethod
+ def create_fake_event(
+ message_str: str,
+ session_id: str = "test_sid",
+ is_at: bool = False,
+ is_group: bool = False,
+ sender_id: str = "123456"
+ ):
+ abm = AstrBotMessage()
+ abm.message_str = message_str
+ abm.group_id = "test"
+ abm.message = [Plain(message_str)]
+ if is_at:
+ abm.message.append(At(qq="bot"))
+ abm.self_id = "bot"
+ abm.sender = MessageMember(sender_id, "mika")
+ abm.timestamp = 1234567890
+ abm.message_id = "test"
+ abm.session_id = session_id
+ if is_group:
+ abm.type = MessageType.GROUP_MESSAGE
+ else:
+ abm.type = MessageType.FRIEND_MESSAGE
+ return FakeAstrMessageEvent(abm)
+
+@pytest.fixture(scope="module")
+def event_queue():
+ return Queue()
+
+@pytest.fixture(scope="module")
+def config():
+ cfg = AstrBotConfig()
+ cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"]
+ cfg['admins_id'] = ["123456"]
+ cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"]
+ cfg['provider'] = [TEST_LLM_PROVIDER]
+ return cfg
+
+@pytest.fixture(scope="module")
+def db():
+ return SQLiteDatabase("data/data_v3.db")
+
+@pytest.fixture(scope="module")
+def platform_manager(event_queue, config):
+ return PlatformManager(config, event_queue)
+
+@pytest.fixture(scope="module")
+def provider_manager(config, db):
+ return ProviderManager(config, db)
+
+@pytest.fixture(scope="module")
+def star_context(event_queue, config, db, platform_manager, provider_manager):
+ star_context = Context(event_queue, config, db, provider_manager, platform_manager)
+ return star_context
+
+@pytest.fixture(scope="module")
+def plugin_manager(star_context, config):
+ plugin_manager = PluginManager(star_context, config)
+ plugin_manager.reload()
+ return plugin_manager
+
+@pytest.fixture(scope="module")
+def pipeline_context(config, plugin_manager):
+ return PipelineContext(config, plugin_manager)
+
+@pytest.fixture(scope="module")
+def pipeline_scheduler(pipeline_context):
+ return PipelineScheduler(pipeline_context)
+
+@pytest.mark.asyncio
+async def test_platform_initialization(platform_manager: PlatformManager):
+ await platform_manager.initialize()
+
+@pytest.mark.asyncio
+async def test_provider_initialization(provider_manager: ProviderManager):
+ await provider_manager.initialize()
+
+@pytest.mark.asyncio
+async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
+ await pipeline_scheduler.initialize()
+
+@pytest.mark.asyncio
+async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
+ '''测试唤醒'''
+ # 群聊无 @ 无指令
+ caplog.clear()
+ mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
+ with caplog.at_level(logging.DEBUG):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages)
+ # 群聊有 @ 无指令
+ mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True)
+ with caplog.at_level(logging.DEBUG):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
+ # 群聊有指令
+ mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST)
+ await pipeline_scheduler.execute(mock_event)
+ assert mock_event._has_send_oper is True
+
+@pytest.mark.asyncio
+async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
+ caplog.clear()
+ mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
+ with caplog.at_level(logging.INFO):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
+
+ mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
+ with caplog.at_level(logging.INFO):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
+
+@pytest.mark.asyncio
+async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
+ # 测试默认屏蔽词
+ caplog.clear()
+ mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
+ with caplog.at_level(logging.INFO):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
+ # 测试额外屏蔽词
+ mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
+ with caplog.at_level(logging.INFO):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
+ mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
+ with caplog.at_level(logging.INFO):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("内容安全检查不通过" not in message for message in caplog.messages)
+ # TODO: 测试 百度AI 的内容安全检查
+
+
+@pytest.mark.asyncio
+async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
+ caplog.clear()
+ mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
+ with caplog.at_level(logging.DEBUG):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("请求 LLM" in message for message in caplog.messages)
+ assert mock_event.get_result() is not None
+ assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
+
+@pytest.mark.asyncio
+async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
+ caplog.clear()
+ mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
+ with caplog.at_level(logging.DEBUG):
+ await pipeline_scheduler.execute(mock_event)
+ assert any("请求 LLM" in message for message in caplog.messages)
+ assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages)
+
+@pytest.mark.asyncio
+async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
+ for command in TEST_COMMANDS:
+ caplog.clear()
+ mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
+ with caplog.at_level(logging.DEBUG):
+ await pipeline_scheduler.execute(mock_event)
+ # assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
+ assert any(command[1] in message for message in caplog.messages)
\ No newline at end of file
diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py
new file mode 100644
index 000000000..8d7b35568
--- /dev/null
+++ b/tests/test_plugin_manager.py
@@ -0,0 +1,80 @@
+import pytest
+import os
+from astrbot.core.star.star_manager import PluginManager
+from astrbot.core.star.star_handler import star_handlers_registry
+from astrbot.core.star.star import star_registry
+from astrbot.core.star.context import Context
+from astrbot.core.config.astrbot_config import AstrBotConfig
+from astrbot.core.db.sqlite import SQLiteDatabase
+from asyncio import Queue
+
+event_queue = Queue()
+
+config = AstrBotConfig()
+
+db = SQLiteDatabase("data/data_v3.db")
+
+star_context = Context(event_queue, config, db)
+
+@pytest.fixture
+def plugin_manager_pm():
+ return PluginManager(star_context, config)
+
+def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
+ assert plugin_manager_pm is not None
+ assert plugin_manager_pm.context is not None
+ assert plugin_manager_pm.config is not None
+
+@pytest.mark.asyncio
+async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
+ success, err_message = plugin_manager_pm.reload()
+ assert success is True
+ assert err_message is None
+ assert len(star_handlers_registry) > 0 # package
+
+@pytest.mark.asyncio
+async def test_plugin_crud(plugin_manager_pm: PluginManager):
+ '''测试插件安装和重载'''
+ os.makedirs("data/plugins", exist_ok=True)
+ test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
+ plugin_path = await plugin_manager_pm.install_plugin(test_repo)
+ exists = False
+ for md in star_registry:
+ if md.name == "astrbot_plugin_essential":
+ exists = True
+ break
+ assert plugin_path is not None
+ assert os.path.exists(plugin_path)
+ assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
+ # shutil.rmtree(plugin_path)
+
+ # install plugin which is not exists
+ with pytest.raises(Exception):
+ plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
+
+ # update
+ await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
+
+ with pytest.raises(Exception):
+ await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
+
+ # uninstall
+ await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
+ assert not os.path.exists(plugin_path)
+ exists = False
+ for md in star_registry:
+ if md.name == "astrbot_plugin_essential":
+ exists = True
+ break
+ assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
+ exists = False
+ for md in star_handlers_registry:
+ if "astrbot_plugin_essential" in md.handler_module_path:
+ exists = True
+ break
+ assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
+
+ with pytest.raises(Exception):
+ await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
+
+ # TODO: file installation