Merge branch 'master' into feat-dify
This commit is contained in:
@@ -1,7 +1,14 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
push
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -21,17 +28,16 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir temp
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
export LLM_MODEL=${{ secrets.LLM_MODEL }}
|
||||
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
|
||||
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
|
||||
@@ -14,6 +14,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
</a>
|
||||
|
||||
<a href="https://astrbot.soulter.top/">查看文档</a> |
|
||||
|
||||
@@ -8,5 +8,9 @@ os.makedirs("data", exist_ok=True)
|
||||
|
||||
html_renderer = HtmlRenderer()
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
if os.environ.get('TESTING', ""):
|
||||
logger.setLevel('DEBUG')
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
@@ -17,6 +17,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"reply_prefix": "",
|
||||
"forward_threshold": 200,
|
||||
"enable_id_white_list": True,
|
||||
"id_whitelist": [],
|
||||
"id_whitelist_log": True,
|
||||
"wl_ignore_admin_on_group": True,
|
||||
@@ -49,7 +50,8 @@ DEFAULT_CONFIG = {
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": ""
|
||||
"plugin_repo_mirror": "",
|
||||
"knowledge_db": {},
|
||||
}
|
||||
|
||||
|
||||
@@ -162,6 +164,10 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。",
|
||||
},
|
||||
"enable_id_white_list": {
|
||||
"description": "启用 ID 白名单",
|
||||
"type": "bool"
|
||||
},
|
||||
"id_whitelist": {
|
||||
"description": "ID 白名单",
|
||||
"type": "list",
|
||||
@@ -273,7 +279,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"zhipu": {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"type": "zhipu_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
|
||||
@@ -16,6 +16,7 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
@@ -29,7 +30,10 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("AstrBot v"+ VERSION)
|
||||
logger.setLevel(self.astrbot_config['log_level'])
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
else:
|
||||
logger.setLevel(self.astrbot_config['log_level'])
|
||||
self.event_queue = Queue()
|
||||
self.event_queue.closed = False
|
||||
|
||||
@@ -37,9 +41,16 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
|
||||
self.star_context.platform_manager = self.platform_manager
|
||||
self.star_context.provider_manager = self.provider_manager
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.knowledge_db_manager
|
||||
)
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
self.plugin_manager.reload()
|
||||
|
||||
@@ -22,6 +22,9 @@ class LLMRequestSubStage(Stage):
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
|
||||
@@ -27,7 +27,7 @@ class StarRequestSubStage(Stage):
|
||||
for handler in activated_handlers:
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_str not in star_map:
|
||||
if handler.handler_module_path not in star_map:
|
||||
# 孤立无援的 star handler
|
||||
continue
|
||||
|
||||
@@ -39,7 +39,7 @@ class StarRequestSubStage(Stage):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
|
||||
try:
|
||||
ready_to_call = handler(event, **params)
|
||||
except TypeError as e:
|
||||
print(e)
|
||||
# 向下兼容
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
||||
|
||||
|
||||
@@ -10,12 +10,16 @@ class WhitelistCheckStage(Stage):
|
||||
'''检查是否在群聊/私聊白名单
|
||||
'''
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
|
||||
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
|
||||
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
|
||||
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
|
||||
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if not self.enable_whitelist_check:
|
||||
return
|
||||
|
||||
# 检查是否在白名单
|
||||
if self.wl_ignore_admin_on_group:
|
||||
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
|
||||
@@ -11,7 +11,7 @@ from botpy import Client
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List
|
||||
from astrbot.api.message_components import Image, Plain
|
||||
from astrbot.api.message_components import Image, Plain, At
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from .qqofficial_message_event import QQOfficialMessageEvent
|
||||
from ...register import register_platform_adapter
|
||||
@@ -111,6 +111,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
abm.message_id = message.id
|
||||
abm.tag = "qq_official"
|
||||
msg: List[BaseMessageComponent] = []
|
||||
|
||||
|
||||
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
|
||||
if isinstance(message, botpy.message.GroupMessage):
|
||||
@@ -126,7 +127,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
)
|
||||
abm.message_str = message.content.strip()
|
||||
abm.self_id = "unknown_selfid"
|
||||
|
||||
msg.append(At(qq="qq_official"))
|
||||
msg.append(Plain(abm.message_str))
|
||||
if message.attachments:
|
||||
for i in message.attachments:
|
||||
@@ -146,7 +147,7 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
plain_content = message.content.replace(
|
||||
"<@!"+str(abm.self_id)+">", "").strip()
|
||||
msg.append(Plain(plain_content))
|
||||
|
||||
if message.attachments:
|
||||
for i in message.attachments:
|
||||
if i.content_type.startswith("image"):
|
||||
@@ -161,11 +162,14 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
str(message.author.id),
|
||||
str(message.author.username)
|
||||
)
|
||||
msg.append(At(qq="qq_official"))
|
||||
msg.append(Plain(plain_content))
|
||||
|
||||
if isinstance(message, botpy.message.Message):
|
||||
abm.group_id = message.channel_id
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {message_type}")
|
||||
abm.self_id = "qq_official"
|
||||
return abm
|
||||
|
||||
def run(self):
|
||||
|
||||
@@ -18,6 +18,11 @@ class ProviderManager():
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
self.curr_kdb_name = ""
|
||||
kdb_cfg = config.get("knowledge_db", {})
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
for provider_cfg in self.providers_config:
|
||||
if not provider_cfg['enable']:
|
||||
continue
|
||||
@@ -29,6 +34,8 @@ class ProviderManager():
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
@@ -54,6 +61,8 @@ class ProviderManager():
|
||||
|
||||
if len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
else:
|
||||
logger.warning("未启用任何大模型提供商适配器。")
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
@@ -162,7 +162,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
@@ -180,8 +185,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
return llm_response
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import traceback
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
|
||||
class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||
# glm-4v-flash 只支持一张图片
|
||||
model: str = model_cfgs.get("model", "")
|
||||
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
|
||||
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||
logger.debug(context_query)
|
||||
new_context_query_ = []
|
||||
for i in range(0, len(context_query) - 1, 2):
|
||||
if isinstance(context_query[i].get("content", ""), list):
|
||||
continue
|
||||
new_context_query_.append(context_query[i])
|
||||
new_context_query_.append(context_query[i+1])
|
||||
new_context_query_.append(context_query[-1]) # 保留最后一条记录
|
||||
context_query = new_context_query_
|
||||
logger.debug(context_query)
|
||||
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**model_cfgs
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
class SimpleOpenAIEmbedding():
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
api_key,
|
||||
api_base=None,
|
||||
) -> None:
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
)
|
||||
self.model = model
|
||||
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
'''
|
||||
获取文本的嵌入
|
||||
'''
|
||||
embedding = await self.client.embeddings.create(
|
||||
input=text,
|
||||
model=self.model
|
||||
)
|
||||
return embedding.data[0].embedding
|
||||
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
class KnowledgeDBManager():
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = "data/knowledge_db/"
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
self.store_insts: Dict[str, Store] = {}
|
||||
for name, cfg in self.config.items():
|
||||
if cfg["strategy"] == "embedding":
|
||||
logger.info(f"加载 Chroma Vector Store:{name}")
|
||||
try:
|
||||
from .store.chroma_db import ChromaVectorStore
|
||||
except ImportError as ie:
|
||||
logger.error(f"{ie} 可能未安装 chromadb 库。")
|
||||
continue
|
||||
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
|
||||
else:
|
||||
logger.error(f"不支持的策略:{cfg['strategy']}")
|
||||
|
||||
|
||||
async def list_knowledge_db(self) -> List[str]:
|
||||
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
|
||||
|
||||
|
||||
async def create_knowledge_db(self, name: str, config: Dict):
|
||||
'''
|
||||
config 格式:
|
||||
```
|
||||
{
|
||||
"strategy": "embedding", # 目前只支持 embedding
|
||||
"chunk_method": {
|
||||
"strategy": "fixed",
|
||||
"chunk_size": 100,
|
||||
"overlap_size": 10
|
||||
},
|
||||
"embedding_config": {
|
||||
"strategy": "openai",
|
||||
"base_url": "",
|
||||
"model": "",
|
||||
"api_key": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
'''
|
||||
if name in self.config:
|
||||
raise ValueError(f"知识库已存在:{name}")
|
||||
|
||||
self.config[name] = config
|
||||
self.astrbot_config["knowledge_db"] = self.config
|
||||
self.astrbot_config.save_config()
|
||||
|
||||
|
||||
async def insert_record(self, name: str, text: str):
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
ret = []
|
||||
match self.config[name]["chunk_method"]['strategy']:
|
||||
case "fixed":
|
||||
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
|
||||
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
|
||||
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
|
||||
case _:
|
||||
pass
|
||||
|
||||
for chunk in ret:
|
||||
await self.store_insts[name].save(chunk)
|
||||
|
||||
|
||||
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
inst = self.store_insts[name]
|
||||
return await inst.query(query, top_n)
|
||||
|
||||
|
||||
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunks.append(text[start:end])
|
||||
start += chunk_size - chunk_overlap
|
||||
return chunks
|
||||
@@ -0,0 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
class Store():
|
||||
async def save(self, text: str):
|
||||
pass
|
||||
|
||||
async def query(self, query: str, top_n: int = 3) -> List[str]:
|
||||
pass
|
||||
@@ -0,0 +1,39 @@
|
||||
import chromadb
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
if embedding_cfg["strategy"] == "openai":
|
||||
self.embedding = SimpleOpenAIEmbedding(
|
||||
model=embedding_cfg["model"],
|
||||
api_key=embedding_cfg["api_key"],
|
||||
api_base=embedding_cfg.get("base_url", None)
|
||||
)
|
||||
|
||||
async def save(self, text: str, metadata: Dict = None):
|
||||
logger.debug(f"Saving text: {text}")
|
||||
embedding = await self.embedding.get_embedding(text)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
ids=str(uuid.uuid4()),
|
||||
embeddings=embedding
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
|
||||
embedding = await self.embedding.get_embedding(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=embedding,
|
||||
n_results=top_n,
|
||||
where=metadata_filter
|
||||
)
|
||||
return results['documents'][0]
|
||||
@@ -14,6 +14,7 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
|
||||
class StarCommand(TypedDict):
|
||||
full_command_name: str
|
||||
@@ -39,10 +40,20 @@ class Context:
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
|
||||
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
||||
def __init__(self,
|
||||
event_queue: Queue,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
for star in star_registry:
|
||||
@@ -73,7 +84,7 @@ class Context:
|
||||
event_type=EventType.OnLLMRequestEvent,
|
||||
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
||||
handler_name=func_obj.__name__,
|
||||
handler_module_str=func_obj.__module__,
|
||||
handler_module_path=func_obj.__module__,
|
||||
handler=func_obj,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
@@ -125,7 +136,7 @@ class Context:
|
||||
event_type=EventType.AdapterMessageEvent,
|
||||
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
||||
handler_name=awaitable.__name__,
|
||||
handler_module_str=awaitable.__module__,
|
||||
handler_module_path=awaitable.__module__,
|
||||
handler=awaitable,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
|
||||
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
ls = re.split(r"\s+", message_str)
|
||||
if self.command_name != ls[0]:
|
||||
return False
|
||||
# if len(self.handler_params) == 0 and len(ls) > 1:
|
||||
# # 一定程度避免 LLM 聊天时误判为指令
|
||||
# return False
|
||||
# params_str = message_str[len(self.command_name):].strip()
|
||||
ls = ls[1:]
|
||||
# 去除空字符串
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
|
||||
event_type=event_type,
|
||||
handler_full_name=handler_full_name,
|
||||
handler_name=handler.__name__,
|
||||
handler_module_str=handler.__module__,
|
||||
handler_module_path=handler.__module__,
|
||||
handler=handler,
|
||||
event_filters=[]
|
||||
)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, List, Dict
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
|
||||
|
||||
class StarHandlerRegistry(List):
|
||||
T = TypeVar('T', bound='StarHandlerMetadata')
|
||||
class StarHandlerRegistry(Generic[T], List[T]):
|
||||
'''用于存储所有的 Star Handler'''
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
@@ -26,8 +26,7 @@ class StarHandlerRegistry(List):
|
||||
|
||||
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
|
||||
'''通过模块名获取 Handler'''
|
||||
return [handler for handler in self if handler.handler_module_str == module_name]
|
||||
|
||||
return [handler for handler in self if handler.handler_module_path == module_name]
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
@@ -55,7 +54,7 @@ class StarHandlerMetadata():
|
||||
handler_name: str
|
||||
'''Handler 的名字,也就是方法名'''
|
||||
|
||||
handler_module_str: str
|
||||
handler_module_path: str
|
||||
'''Handler 所在的模块路径。'''
|
||||
|
||||
handler: Awaitable
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
@@ -14,9 +15,8 @@ from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
class PluginManager:
|
||||
def __init__(
|
||||
@@ -138,7 +138,18 @@ class PluginManager:
|
||||
|
||||
def reload(self):
|
||||
'''扫描并加载所有的 Star'''
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
smd.star_cls.__del__()
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_handlers_registry.star_handlers_map.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
@@ -225,10 +236,11 @@ class PluginManager:
|
||||
|
||||
async def install_plugin(self, repo_url: str):
|
||||
plugin_path = await self.updator.install(repo_url)
|
||||
self._check_plugin_dept_update()
|
||||
# reload the plugin
|
||||
self.reload()
|
||||
return plugin_path
|
||||
|
||||
def uninstall_plugin(self, plugin_name: str):
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
@@ -237,7 +249,20 @@ class PluginManager:
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
del star_map[plugin.module_path]
|
||||
for i, p in enumerate(star_registry):
|
||||
if p.name == plugin_name:
|
||||
del star_registry[i]
|
||||
break
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(plugin.module_path):
|
||||
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
|
||||
star_handlers_registry.remove(handler)
|
||||
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin.module_path)]
|
||||
for k in keys_to_delete:
|
||||
v = star_handlers_registry.star_handlers_map[k]
|
||||
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
@@ -250,6 +275,7 @@ class PluginManager:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||
|
||||
await self.updator.update(plugin)
|
||||
self.reload()
|
||||
|
||||
def install_plugin_from_file(self, zip_file_path: str):
|
||||
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
|
||||
@@ -262,3 +288,4 @@ class PluginManager:
|
||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||
|
||||
self._check_plugin_dept_update()
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
|
||||
|
||||
files = os.listdir(os.path.join(target_dir, update_dir))
|
||||
for f in files:
|
||||
logger.info(f"移动更新文件/目录: {f}")
|
||||
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
|
||||
if os.path.exists(os.path.join(target_dir, f)):
|
||||
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
|
||||
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
|
||||
|
||||
try:
|
||||
logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
logger.info(f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
|
||||
os.remove(zip_path)
|
||||
except BaseException:
|
||||
|
||||
@@ -111,7 +111,7 @@ class RepoZipUpdator():
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
@@ -53,7 +53,6 @@ class PluginRoute(Route):
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
await self.plugin_manager.install_plugin(repo_url)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {repo_url} 成功。")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
@@ -69,7 +68,6 @@ class PluginRoute(Route):
|
||||
await file.save(file_path)
|
||||
self.plugin_manager.install_plugin_from_file(file_path)
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
self.core_lifecycle.restart()
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -80,7 +78,7 @@ class PluginRoute(Route):
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在卸载插件 {plugin_name}")
|
||||
self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
await self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
logger.info(f"卸载插件 {plugin_name} 成功")
|
||||
return Response().ok(None, "卸载成功").__dict__
|
||||
except Exception as e:
|
||||
@@ -93,9 +91,8 @@ class PluginRoute(Route):
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
await self.plugin_manager.update_plugin(plugin_name)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
|
||||
logger.info(f"更新插件 {plugin_name} 成功。")
|
||||
return Response().ok(None, "更新成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -32,6 +32,7 @@ class UpdateRoute(Route):
|
||||
async def update_project(self):
|
||||
data = await request.json
|
||||
version = data.get('version', '')
|
||||
reboot = data.get('reboot', True)
|
||||
if version == "" or version == "latest":
|
||||
latest = True
|
||||
version = ''
|
||||
@@ -39,8 +40,11 @@ class UpdateRoute(Route):
|
||||
latest = False
|
||||
try:
|
||||
await self.astrbot_updator.update(latest=latest, version=version)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
if reboot:
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
else:
|
||||
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
@@ -42,14 +41,16 @@ async def check_dashboard_files():
|
||||
return
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
logger.info("开始下载管理面板文件...")
|
||||
ok = False
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(dashboard_release_url) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"下载管理面板文件失败: {resp.status}")
|
||||
with open("data/dashboard.zip", "wb") as f:
|
||||
f.write(await resp.read())
|
||||
logger.info("管理面板文件下载完成。")
|
||||
ok = True
|
||||
else:
|
||||
with open("data/dashboard.zip", "wb") as f:
|
||||
f.write(await resp.read())
|
||||
logger.info("管理面板文件下载完成。")
|
||||
ok = True
|
||||
|
||||
if not ok:
|
||||
logger.critical("下载管理面板文件失败")
|
||||
|
||||
@@ -16,6 +16,8 @@ class Main(star.Star):
|
||||
self.prompt_prefix = cfg['provider_settings']['prompt_prefix']
|
||||
self.identifier = cfg['provider_settings']['identifier']
|
||||
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
|
||||
|
||||
self.kdb_enabled = False
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
@@ -289,7 +291,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
- 重置 LLM 会话(保留人格): /reset p
|
||||
|
||||
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
|
||||
"""))
|
||||
""").use_t2i(False))
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for key in personalities.keys():
|
||||
@@ -337,4 +339,33 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
|
||||
async def other_message(self, event: AstrMessageEvent):
|
||||
print("triggered")
|
||||
event.stop_event()
|
||||
event.stop_event()
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
pass
|
||||
|
||||
@kdb.command("on")
|
||||
async def on_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = True
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if not curr_kdb_name:
|
||||
yield event.plain_result("未载入任何知识库")
|
||||
else:
|
||||
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
|
||||
|
||||
@kdb.command("off")
|
||||
async def off_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = False
|
||||
yield event.plain_result("知识库已关闭")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if self.kdb_enabled and curr_kdb_name:
|
||||
mgr = self.context.knowledge_db_manager
|
||||
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
|
||||
if results:
|
||||
req.system_prompt += "\nHere are documents that related to user's query: \n"
|
||||
for result in results:
|
||||
req.system_prompt += f"- {result}\n"
|
||||
@@ -0,0 +1,148 @@
|
||||
import pytest
|
||||
import os
|
||||
from quart import Quart
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def core_lifecycle_td():
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
log_broker = LogBroker()
|
||||
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
|
||||
return core_lifecycle_td
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app(core_lifecycle_td):
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
server = AstrBotDashboard(core_lifecycle_td, db)
|
||||
return server.app
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def header():
|
||||
return {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_core_lifecycle_td(core_lifecycle_td):
|
||||
await core_lifecycle_td.initialize()
|
||||
assert core_lifecycle_td is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post('/api/auth/login', json={
|
||||
"username": "wrong",
|
||||
"password": "password"
|
||||
})
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'error'
|
||||
|
||||
response = await test_client.post('/api/auth/login', json={
|
||||
"username": core_lifecycle_td.astrbot_config['dashboard']['username'],
|
||||
"password": core_lifecycle_td.astrbot_config['dashboard']['password']
|
||||
})
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok' and 'token' in data['data']
|
||||
header['Authorization'] = f"Bearer {data['data']['token']}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stat(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get('/api/stat/get')
|
||||
assert response.status_code == 401
|
||||
response = await test_client.get('/api/stat/get', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok' and 'platform' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
# 已经安装的插件
|
||||
response = await test_client.get('/api/plugin/get', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件市场
|
||||
response = await test_client.get('/api/plugin/market_list', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件安装
|
||||
response = await test_client.post('/api/plugin/install', json={
|
||||
"url": "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
|
||||
# 插件更新
|
||||
response = await test_client.post('/api/plugin/update', json={
|
||||
"name": "astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件卸载
|
||||
response = await test_client.post('/api/plugin/uninstall', json={
|
||||
"name": "astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_update(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get('/api/update/check', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'success'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update(app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
global VERSION
|
||||
test_client = app.test_client()
|
||||
os.makedirs("data/astrbot_release", exist_ok=True)
|
||||
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
|
||||
VERSION = "114.514.1919810"
|
||||
response = await test_client.post('/api/update/do', headers=header, json={
|
||||
"version": "latest"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'error' # 已经是最新版本
|
||||
|
||||
response = await test_client.post('/api/update/do', headers=header, json={
|
||||
"version": "v3.4.0",
|
||||
"reboot": False
|
||||
})
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
assert os.path.exists("data/astrbot_release/astrbot")
|
||||
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from main import check_env, check_dashboard_files
|
||||
|
||||
class _version_info():
|
||||
def __init__(self, major, minor):
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
|
||||
def test_check_env(monkeypatch):
|
||||
version_info_correct = _version_info(3, 10)
|
||||
version_info_wrong = _version_info(3, 9)
|
||||
monkeypatch.setattr(sys, 'version_info', version_info_correct)
|
||||
with mock.patch('os.makedirs') as mock_makedirs:
|
||||
check_env()
|
||||
mock_makedirs.assert_any_call("data/config", exist_ok=True)
|
||||
mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
|
||||
mock_makedirs.assert_any_call("data/temp", exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(sys, 'version_info', version_info_wrong)
|
||||
with pytest.raises(SystemExit):
|
||||
check_env()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files(monkeypatch):
|
||||
monkeypatch.setattr(os.path, 'exists', lambda x: False)
|
||||
async def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
status = 200
|
||||
async def read(self):
|
||||
return b'content'
|
||||
return MockResponse()
|
||||
|
||||
with mock.patch('aiohttp.ClientSession.get', new=mock_get):
|
||||
with mock.patch('builtins.open', mock.mock_open()) as mock_file:
|
||||
with mock.patch('zipfile.ZipFile.extractall') as mock_extractall:
|
||||
async def mock_aenter(_):
|
||||
await check_dashboard_files()
|
||||
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
|
||||
mock_extractall.assert_called_once()
|
||||
|
||||
async def mock_aexit(obj, exc_type, exc, tb):
|
||||
return
|
||||
|
||||
mock_extractall.__aenter__ = mock_aenter
|
||||
mock_extractall.__aexit__ = mock_aexit
|
||||
@@ -0,0 +1,224 @@
|
||||
import pytest
|
||||
import logging
|
||||
import os
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core.message.components import Plain, At
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from asyncio import Queue
|
||||
|
||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
||||
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
|
||||
TEST_LLM_PROVIDER = {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [os.getenv("ZHIPU_API_KEY")],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
}
|
||||
|
||||
TEST_COMMANDS = [
|
||||
["help", "已注册的 AstrBot 内置指令"],
|
||||
["tool ls", "函数工具"],
|
||||
["tool on websearch", "激活工具"],
|
||||
["tool off websearch", "停用工具"],
|
||||
["plugin", "已加载的插件"],
|
||||
["t2i", "文本转图片模式"],
|
||||
["sid", "此 ID 可用于设置会话白名单。"],
|
||||
["op test_op", "授权成功。"],
|
||||
["deop test_op", "取消授权成功。"],
|
||||
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
|
||||
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
|
||||
["provider", "当前载入的 LLM 提供商"],
|
||||
["reset", "重置成功"],
|
||||
# ["model", "查看、切换提供商模型列表"],
|
||||
["history", "历史记录:"],
|
||||
["key", "当前 Key"],
|
||||
["persona", "[Persona]"]
|
||||
]
|
||||
|
||||
class FakeAstrMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, abm: AstrBotMessage = None):
|
||||
meta = PlatformMetadata("test_platform", "test")
|
||||
super().__init__(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=meta,
|
||||
session_id=abm.session_id
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
def create_fake_event(
|
||||
message_str: str,
|
||||
session_id: str = "test_sid",
|
||||
is_at: bool = False,
|
||||
is_group: bool = False,
|
||||
sender_id: str = "123456"
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
abm.message_str = message_str
|
||||
abm.group_id = "test"
|
||||
abm.message = [Plain(message_str)]
|
||||
if is_at:
|
||||
abm.message.append(At(qq="bot"))
|
||||
abm.self_id = "bot"
|
||||
abm.sender = MessageMember(sender_id, "mika")
|
||||
abm.timestamp = 1234567890
|
||||
abm.message_id = "test"
|
||||
abm.session_id = session_id
|
||||
if is_group:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return FakeAstrMessageEvent(abm)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_queue():
|
||||
return Queue()
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config():
|
||||
cfg = AstrBotConfig()
|
||||
cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"]
|
||||
cfg['admins_id'] = ["123456"]
|
||||
cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"]
|
||||
cfg['provider'] = [TEST_LLM_PROVIDER]
|
||||
return cfg
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def db():
|
||||
return SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform_manager(event_queue, config):
|
||||
return PlatformManager(config, event_queue)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def provider_manager(config, db):
|
||||
return ProviderManager(config, db)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def star_context(event_queue, config, db, platform_manager, provider_manager):
|
||||
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
|
||||
return star_context
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def plugin_manager(star_context, config):
|
||||
plugin_manager = PluginManager(star_context, config)
|
||||
plugin_manager.reload()
|
||||
return plugin_manager
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_context(config, plugin_manager):
|
||||
return PipelineContext(config, plugin_manager)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_scheduler(pipeline_context):
|
||||
return PipelineScheduler(pipeline_context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_initialization(platform_manager: PlatformManager):
|
||||
await platform_manager.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_initialization(provider_manager: ProviderManager):
|
||||
await provider_manager.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
|
||||
await pipeline_scheduler.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
'''测试唤醒'''
|
||||
# 群聊无 @ 无指令
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages)
|
||||
# 群聊有 @ 无指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
|
||||
# 群聊有指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST)
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert mock_event._has_send_oper is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
# 测试默认屏蔽词
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
# 测试额外屏蔽词
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" not in message for message in caplog.messages)
|
||||
# TODO: 测试 百度AI 的内容安全检查
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert mock_event.get_result() is not None
|
||||
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
for command in TEST_COMMANDS:
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
|
||||
assert any(command[1] in message for message in caplog.messages)
|
||||
@@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
import os
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from asyncio import Queue
|
||||
|
||||
event_queue = Queue()
|
||||
|
||||
config = AstrBotConfig()
|
||||
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
star_context = Context(event_queue, config, db)
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_manager_pm():
|
||||
return PluginManager(star_context, config)
|
||||
|
||||
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||
assert plugin_manager_pm is not None
|
||||
assert plugin_manager_pm.context is not None
|
||||
assert plugin_manager_pm.config is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
success, err_message = plugin_manager_pm.reload()
|
||||
assert success is True
|
||||
assert err_message is None
|
||||
assert len(star_handlers_registry) > 0 # package
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
||||
'''测试插件安装和重载'''
|
||||
os.makedirs("data/plugins", exist_ok=True)
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert plugin_path is not None
|
||||
assert os.path.exists(plugin_path)
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
# shutil.rmtree(plugin_path)
|
||||
|
||||
# install plugin which is not exists
|
||||
with pytest.raises(Exception):
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
|
||||
|
||||
# update
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# uninstall
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
||||
assert not os.path.exists(plugin_path)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# TODO: file installation
|
||||
Reference in New Issue
Block a user