Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f8ab40eb39 | |||
| 55e9233b93 | |||
| b7277b51fd | |||
| 1fa9111b2b | |||
| 90a9e496d9 | |||
| 2a7dce1eb0 | |||
| 0c0841cc03 | |||
| 4c9fe016bf | |||
| acc90f140c | |||
| 68a7bc3930 | |||
| 12ea64be0e | |||
| 7f30a673f7 | |||
| 897e100c32 | |||
| 0d4ad5cb31 | |||
| b124bd0d0e | |||
| 6bc2f84602 | |||
| d787a28c40 | |||
| 6b078a5731 | |||
| 17dddbfe21 | |||
| 3ff3c9e144 | |||
| f5a37d82cc | |||
| d3d428dc9d | |||
| 8dc8c5b5dc | |||
| e6b06f914b | |||
| 4dc502a8b6 | |||
| b1d1a13d5f | |||
| 75cc4cac5a | |||
| 1b7e4fbbdc | |||
| 9789e2f6c1 | |||
| b8fb0bee24 | |||
| 419f77e245 | |||
| 59b1c3473b | |||
| 6db58ca375 | |||
| 4832b342b0 | |||
| 6cec542402 | |||
| 9644791783 | |||
| 5031c307d1 | |||
| aa49539e3e | |||
| 7b4118493b | |||
| d1cc9ba4ce | |||
| e0e92139d7 | |||
| 62039392bb | |||
| b72c69892e | |||
| e6205e9aad | |||
| b8a6fb1720 | |||
| 7c06d82f27 | |||
| d92cb0f500 | |||
| 7fa72f2fe9 |
@@ -2,6 +2,7 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
name: Auto Release
|
||||
|
||||
@@ -14,6 +15,15 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Dashboard Build
|
||||
run: |
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
|
||||
echo ${{ github.ref_name }} > dist/assets/version
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Fetch Changelog
|
||||
run: |
|
||||
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
||||
@@ -21,4 +31,5 @@ jobs:
|
||||
- name: Create Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
bodyFile: ${{ env.changelog }}
|
||||
bodyFile: ${{ env.changelog }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
@@ -1,7 +1,14 @@
|
||||
name: Run tests and upload coverage
|
||||
|
||||
on:
|
||||
push
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
paths-ignore:
|
||||
- 'README.md'
|
||||
- 'changelogs/**'
|
||||
- 'dashboard/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
@@ -21,17 +28,16 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
pip install pytest pytest-cov pytest-asyncio
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir temp
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
export LLM_MODEL=${{ secrets.LLM_MODEL }}
|
||||
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
|
||||
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v
|
||||
mkdir data
|
||||
mkdir data/plugins
|
||||
mkdir data/config
|
||||
mkdir data/temp
|
||||
export TESTING=true
|
||||
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
|
||||
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
|
||||
|
||||
- name: Upload results to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
|
||||
@@ -14,9 +14,10 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
</a>
|
||||
|
||||
<a href="https://astrbot.soulter.top/">查看文档</a> |
|
||||
<a href="https://astrbot.lwl.lol/">查看文档</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
@@ -24,7 +25,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
## ✨ 多消息平台部署
|
||||
|
||||
1. QQ 群、QQ 频道、微信、Telegram。
|
||||
1. QQ 群、QQ 频道、微信个人号、Telegram。
|
||||
2. 支持文本转图片,Markdown 渲染。
|
||||
|
||||
## ✨ 多 LLM 配置
|
||||
@@ -33,7 +34,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
2. 支持 OneAPI 等分发平台。
|
||||
3. 支持 LLMTuner 载入微调模型。
|
||||
4. 支持 Ollama 载入自部署模型。
|
||||
4. 支持网页搜索(Web Search)。
|
||||
4. 支持网页搜索(Web Search)、自然语言待办提醒。
|
||||
|
||||
## ✨ 管理面板
|
||||
|
||||
@@ -42,6 +43,38 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
3. 简单的信息统计
|
||||
4. 插件管理
|
||||
|
||||
## ✨ 支持 Dify
|
||||
|
||||
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流!
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
|
||||
_✨ 自然语言待办事项 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
|
||||
_✨ 插件系统——部分插件展示 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/caadf2bd-a0ee-43d0-a95e-566d63e3e34d" height=330>
|
||||
<img src="https://github.com/user-attachments/assets/b418f281-e920-49db-9fe1-d6a13ce28a84" height=350>
|
||||
|
||||
_✨ 管理面板 ✨_
|
||||
|
||||
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<!-- ## ✨ ATRI [Beta 测试]
|
||||
|
||||
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||
|
||||
@@ -2,6 +2,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
__all__ = [
|
||||
@@ -10,4 +11,5 @@ __all__ = [
|
||||
"personalities",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"sp"
|
||||
]
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
from .log import LogManager, LogBroker
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
|
||||
@@ -8,5 +9,10 @@ os.makedirs("data", exist_ok=True)
|
||||
|
||||
html_renderer = HtmlRenderer()
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
if os.environ.get('TESTING', ""):
|
||||
logger.setLevel('DEBUG')
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
sp = SharedPreferences() # 简单的偏好设置存储
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.2"
|
||||
VERSION = "3.4.3"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -50,7 +50,8 @@ DEFAULT_CONFIG = {
|
||||
"log_level": "INFO",
|
||||
"t2i_endpoint": "",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": ""
|
||||
"plugin_repo_mirror": "",
|
||||
"knowledge_db": {},
|
||||
}
|
||||
|
||||
|
||||
@@ -230,10 +231,10 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"provider_group": {
|
||||
"name": "大语言模型",
|
||||
"name": "服务提供商",
|
||||
"metadata": {
|
||||
"provider": {
|
||||
"description": "大语言模型配置",
|
||||
"description": "服务提供商配置",
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"openai": {
|
||||
@@ -256,7 +257,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "llama3.1-8b",
|
||||
},
|
||||
},
|
||||
"gemini": {
|
||||
"gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
@@ -266,6 +267,16 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
},
|
||||
"gemini(googlegenai原生)": {
|
||||
"id": "gemini_default",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
},
|
||||
},
|
||||
"deepseek": {
|
||||
"id": "deepseek_default",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -278,7 +289,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"zhipu": {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"type": "zhipu_chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
@@ -295,6 +306,15 @@ CONFIG_METADATA_2 = {
|
||||
"llmtuner_template": "",
|
||||
"finetuning_type": "lora",
|
||||
"quantization_bit": 4,
|
||||
},
|
||||
"dify": {
|
||||
"id": "dify_app_default",
|
||||
"type": "dify",
|
||||
"enable": True,
|
||||
"dify_api_type": "chat",
|
||||
"dify_api_key": "",
|
||||
"dify_api_base": "https://api.dify.ai/v1",
|
||||
"dify_workflow_output_key": "",
|
||||
}
|
||||
},
|
||||
"items": {
|
||||
@@ -366,6 +386,27 @@ CONFIG_METADATA_2 = {
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
},
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
"hint": "Dify API Key。此项必填。",
|
||||
},
|
||||
"dify_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "Dify API Base URL。默认为 https://api.dify.ai/v1",
|
||||
},
|
||||
"dify_api_type": {
|
||||
"description": "Dify 应用类型",
|
||||
"type": "string",
|
||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
|
||||
"options": ["chat", "agent", "workflow"],
|
||||
},
|
||||
"dify_workflow_output_key": {
|
||||
"description": "Dify Workflow 输出变量名",
|
||||
"type": "string",
|
||||
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
|
||||
}
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
|
||||
@@ -16,6 +16,7 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
@@ -29,7 +30,10 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("AstrBot v"+ VERSION)
|
||||
logger.setLevel(self.astrbot_config['log_level'])
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
else:
|
||||
logger.setLevel(self.astrbot_config['log_level'])
|
||||
self.event_queue = Queue()
|
||||
self.event_queue.closed = False
|
||||
|
||||
@@ -37,12 +41,19 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
|
||||
self.star_context.platform_manager = self.platform_manager
|
||||
self.star_context.provider_manager = self.provider_manager
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.knowledge_db_manager
|
||||
)
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
self.plugin_manager.reload()
|
||||
await self.plugin_manager.reload()
|
||||
'''扫描、注册插件、实例化插件类'''
|
||||
|
||||
await self.provider_manager.initialize()
|
||||
@@ -81,6 +92,8 @@ class AstrBotCoreLifecycle:
|
||||
self.event_queue.closed = True
|
||||
for task in self.curr_tasks:
|
||||
task.cancel()
|
||||
|
||||
await self.provider_manager.terminate()
|
||||
|
||||
for task in self.curr_tasks:
|
||||
try:
|
||||
|
||||
@@ -54,6 +54,7 @@ class ComponentType(Enum):
|
||||
CardImage = "CardImage"
|
||||
TTS = "TTS"
|
||||
Unknown = "Unknown"
|
||||
File = "File"
|
||||
|
||||
|
||||
class BaseMessageComponent(BaseModel):
|
||||
@@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
|
||||
def toString(self):
|
||||
return ""
|
||||
|
||||
class File(BaseMessageComponent):
|
||||
'''
|
||||
目前此消息段只适配了 Napcat。
|
||||
'''
|
||||
type: ComponentType = "File"
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file: T.Optional[str] = "" # url(本地路径)
|
||||
|
||||
def __init__(self, name: str, file: str):
|
||||
super().__init__(name=name, file=file)
|
||||
|
||||
|
||||
ComponentTypes = {
|
||||
"plain": Plain,
|
||||
@@ -441,5 +453,6 @@ ComponentTypes = {
|
||||
"json": Json,
|
||||
"cardimage": CardImage,
|
||||
"tts": TTS,
|
||||
"unknown": Unknown
|
||||
"unknown": Unknown,
|
||||
'file': File,
|
||||
}
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
'''
|
||||
Dify 调用 Stage
|
||||
'''
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
|
||||
class DifyRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if provider.meta().type != "dify":
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
req.image_urls.append(image_url)
|
||||
req.session_id = event.session_id
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
# text completion
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
||||
yield # rick roll
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e)))
|
||||
return
|
||||
@@ -1,3 +1,6 @@
|
||||
'''
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
'''
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
@@ -15,10 +18,13 @@ class LLMRequestSubStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if provider is None:
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
@@ -38,6 +44,9 @@ class LLMRequestSubStage(Stage):
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
req.contexts = session_provider_context if session_provider_context else []
|
||||
|
||||
if not req.prompt:
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件。
|
||||
# 装饰 system_prompt 等功能
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
||||
@@ -48,7 +57,9 @@ class LLMRequestSubStage(Stage):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
logger.debug(f"请求 LLM:{req.__dict__}")
|
||||
logger.debug(f"提供商请求 Payload: {req.__dict__}")
|
||||
if _nested:
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
@@ -82,7 +93,7 @@ class LLMRequestSubStage(Stage):
|
||||
for tool_name, tool_result in function_calling_result.items():
|
||||
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
|
||||
req.prompt += extra_prompt
|
||||
async for _ in self.process(event):
|
||||
async for _ in self.process(event, _nested=True):
|
||||
yield
|
||||
|
||||
except BaseException as e:
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
'''
|
||||
本地 Agent 模式的 AstrBot 插件调用 Stage
|
||||
'''
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
@@ -24,7 +27,7 @@ class StarRequestSubStage(Stage):
|
||||
for handler in activated_handlers:
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_str not in star_map:
|
||||
if handler.handler_module_path not in star_map:
|
||||
# 孤立无援的 star handler
|
||||
continue
|
||||
|
||||
@@ -36,7 +39,7 @@ class StarRequestSubStage(Stage):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||
event.set_result(MessageEventResult().message(ret))
|
||||
yield
|
||||
event.clear_result()
|
||||
|
||||
@@ -3,6 +3,7 @@ from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
from .method.dify_request import DifyRequestSubStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
@@ -20,11 +21,15 @@ class ProcessStage(Stage):
|
||||
|
||||
self.star_request_sub_stage = StarRequestSubStage()
|
||||
await self.star_request_sub_stage.initialize(ctx)
|
||||
|
||||
self.dify_request_sub_stage = DifyRequestSubStage()
|
||||
await self.dify_request_sub_stage.initialize(ctx)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''处理事件
|
||||
'''
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
# 有插件 Handler 被激活
|
||||
if activated_handlers:
|
||||
async for resp in self.star_request_sub_stage.process(event):
|
||||
# 生成器返回值处理
|
||||
@@ -36,10 +41,18 @@ class ProcessStage(Stage):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
if not event._has_send_oper:
|
||||
'''当没有发送操作'''
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
# 调用提供商相关请求
|
||||
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
return
|
||||
|
||||
if not event._has_send_oper and event.is_at_or_wake_command:
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
match provider.meta().type:
|
||||
case "dify":
|
||||
async for _ in self.dify_request_sub_stage.process(event):
|
||||
yield
|
||||
case _:
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
@@ -22,4 +22,6 @@ class RespondStage(Stage):
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
for handler in handlers:
|
||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
||||
await handler.handler(event)
|
||||
await handler.handler(event)
|
||||
|
||||
event.clear_result()
|
||||
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
|
||||
try:
|
||||
ready_to_call = handler(event, **params)
|
||||
except TypeError as e:
|
||||
print(e)
|
||||
# 向下兼容
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ class WakingCheckStage(Stage):
|
||||
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
|
||||
break
|
||||
is_wake = True
|
||||
event.is_at_or_wake_command = True
|
||||
event.is_wake = True
|
||||
event.message_str = event.message_str[len(wake_prefix) :].strip()
|
||||
break
|
||||
@@ -60,11 +61,13 @@ class WakingCheckStage(Stage):
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
wake_prefix = ""
|
||||
event.is_at_or_wake_command = True
|
||||
break
|
||||
# 检查是否是私聊
|
||||
if event.is_private_chat():
|
||||
is_wake = True
|
||||
event.is_wake = True
|
||||
event.is_at_or_wake_command = True
|
||||
wake_prefix = ""
|
||||
|
||||
# 检查插件的 handler filter
|
||||
|
||||
@@ -35,7 +35,8 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.platform_meta = platform_meta
|
||||
self.session_id = session_id
|
||||
self.role = "member"
|
||||
self.is_wake = False
|
||||
self.is_wake = False # 是否通过 WakingStage
|
||||
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True)
|
||||
self._extras = {}
|
||||
self.session = MessageSesion(
|
||||
platform_name=platform_meta.name,
|
||||
|
||||
@@ -2,4 +2,6 @@ from dataclasses import dataclass
|
||||
@dataclass
|
||||
class PlatformMetadata():
|
||||
name: str # 平台的名称
|
||||
description: str # 平台的描述
|
||||
description: str # 平台的描述
|
||||
|
||||
default_config_tmpl: dict = None # 平台的默认配置模板
|
||||
@@ -7,15 +7,26 @@ platform_registry: List[PlatformMetadata] = []
|
||||
platform_cls_map: Dict[str, Type] = {}
|
||||
'''维护了平台适配器名称和适配器类的映射'''
|
||||
|
||||
def register_platform_adapter(adapter_name: str, desc: str):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
|
||||
'''用于注册平台适配器的带参装饰器。
|
||||
|
||||
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
|
||||
'''
|
||||
def decorator(cls):
|
||||
if adapter_name in platform_cls_map:
|
||||
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
|
||||
|
||||
# 添加必备选项
|
||||
if default_config_tmpl:
|
||||
if 'type' not in default_config_tmpl:
|
||||
default_config_tmpl['type'] = adapter_name
|
||||
if 'enable' not in default_config_tmpl:
|
||||
default_config_tmpl['enable'] = False
|
||||
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
default_config_tmpl=default_config_tmpl
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
@@ -5,12 +6,13 @@ from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from .aiocqhttp_message_event import *
|
||||
from astrbot.api.message_components import *
|
||||
from .aiocqhttp_message_event import * # noqa: F403
|
||||
from astrbot.api.message_components import * # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
|
||||
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
|
||||
class AiocqhttpAdapter(Platform):
|
||||
@@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform):
|
||||
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
async def convert_message(self, event: Event) -> AstrBotMessage:
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.tag = "aiocqhttp"
|
||||
@@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform):
|
||||
a = None
|
||||
if t == 'text':
|
||||
message_str += m['data']['text'].strip()
|
||||
a = ComponentTypes[t](**m['data'])
|
||||
elif t == 'file':
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
|
||||
if not ret.get('file', None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret['file']):
|
||||
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
|
||||
|
||||
m['data'] = {
|
||||
"file": ret['file'],
|
||||
"name": ret['file_name']
|
||||
}
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
a = ComponentTypes[t](**m['data']) # noqa: F405
|
||||
abm.message.append(a)
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
@@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform):
|
||||
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
|
||||
@self.bot.on_message('group')
|
||||
async def group(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@self.bot.on_message('private')
|
||||
async def private(event: Event):
|
||||
abm = self.convert_message(event)
|
||||
abm = await self.convert_message(event)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
|
||||
|
||||
@@ -31,11 +31,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
|
||||
payload['media'] = media
|
||||
payload['msg_type'] = 7
|
||||
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
|
||||
case botpy.message.C2CMessage:
|
||||
if image_base64:
|
||||
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
|
||||
payload['media'] = media
|
||||
payload['msg_type'] = 7
|
||||
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
|
||||
case botpy.message.Message:
|
||||
if image_path:
|
||||
@@ -73,9 +75,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
plain_text += i.text
|
||||
elif isinstance(i, Image) and not image_base64:
|
||||
if i.file and i.file.startswith("file:///"):
|
||||
image_base64 = file_to_base64(i.file[8:])
|
||||
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
|
||||
image_file_path = i.file[8:]
|
||||
elif i.file and i.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(i.file)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
|
||||
return plain_text, image_base64, image_file_path
|
||||
@@ -2,6 +2,7 @@ import sys
|
||||
import time
|
||||
import uuid
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
@@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform):
|
||||
self.start_time = int(time.time())
|
||||
return self._run()
|
||||
|
||||
|
||||
|
||||
async def _run(self):
|
||||
await self.client.init()
|
||||
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict
|
||||
from .func_tool_manager import FuncCall
|
||||
|
||||
@@ -32,9 +32,9 @@ class ProviderRequest():
|
||||
class LLMResponse:
|
||||
role: str
|
||||
'''角色'''
|
||||
completion_text: str = None
|
||||
completion_text: str = ""
|
||||
'''LLM 返回的文本'''
|
||||
tools_call_args: List[Dict[str, any]] = None
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
'''工具调用参数'''
|
||||
tools_call_name: List[str] = None
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
'''工具调用名称'''
|
||||
@@ -14,6 +14,7 @@ class FuncTool:
|
||||
parameters: Dict
|
||||
description: str
|
||||
handler: Awaitable
|
||||
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
|
||||
active: bool = True
|
||||
'''是否激活'''
|
||||
@@ -100,10 +101,29 @@ class FuncCall:
|
||||
}
|
||||
)
|
||||
return _l
|
||||
|
||||
def get_func_desc_google_genai_style(self) -> Dict:
|
||||
declarations = {}
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
tools.append(
|
||||
{
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
}
|
||||
)
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"name": f["name"],
|
||||
@@ -169,3 +189,10 @@ class FuncCall:
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
return tool_call_result, True
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
@@ -5,7 +5,7 @@ from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from collections import defaultdict
|
||||
from .register import provider_cls_map, llm_tools
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, sp
|
||||
|
||||
class ProviderManager():
|
||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||
@@ -18,6 +18,11 @@ class ProviderManager():
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
self.curr_kdb_name = ""
|
||||
kdb_cfg = config.get("knowledge_db", {})
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
for provider_cfg in self.providers_config:
|
||||
if not provider_cfg['enable']:
|
||||
continue
|
||||
@@ -29,9 +34,15 @@ class ProviderManager():
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify # noqa: F401
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
@@ -39,19 +50,31 @@ class ProviderManager():
|
||||
if not provider_config['enable']:
|
||||
continue
|
||||
if provider_config['type'] not in provider_cls_map:
|
||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
selected_provider_id = sp.get("curr_provider")
|
||||
cls_type = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
try:
|
||||
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
self.provider_insts.append(inst)
|
||||
if selected_provider_id == provider_config['id']:
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 失败:{e}")
|
||||
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
|
||||
|
||||
if len(self.provider_insts) > 0:
|
||||
if len(self.provider_insts) > 0 and not self.curr_provider_inst:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何提供商适配器。")
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
return self.provider_insts
|
||||
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
await provider_inst.terminate()
|
||||
@@ -0,0 +1,131 @@
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
class ProviderDify(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config, provider_settings, persistant_history, db_helper
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Dify API Key 不能为空。")
|
||||
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||
self.api_client = DifyAPIClient(self.api_key, api_base)
|
||||
self.api_type = provider_config.get("dify_api_type", "")
|
||||
if not self.api_type:
|
||||
raise Exception("Dify API 类型不能为空。")
|
||||
self.model_name = "dify"
|
||||
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
|
||||
|
||||
self.conversation_ids = {}
|
||||
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
result = ""
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
file_response = await self.api_client.file_upload(image_path, user=session_id)
|
||||
if 'id' not in file_response:
|
||||
logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。")
|
||||
continue
|
||||
files_payload.append({
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response['id'],
|
||||
})
|
||||
else:
|
||||
# TODO: 处理更多情况
|
||||
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
|
||||
|
||||
logger.debug(files_payload)
|
||||
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk['event'] == "message" or \
|
||||
chunk['event'] == "agent_message":
|
||||
result += chunk['answer']
|
||||
if not conversation_id:
|
||||
self.conversation_ids[session_id] = chunk['conversation_id']
|
||||
conversation_id = chunk['conversation_id']
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
"astrbot_text_query": prompt,
|
||||
"astrbot_session_id": session_id
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload
|
||||
):
|
||||
match chunk['event']:
|
||||
case "workflow_started":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
|
||||
case "node_finished":
|
||||
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
|
||||
case "workflow_finished":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
|
||||
if chunk['data']['error']:
|
||||
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
if self.workflow_output_key not in chunk['data']['outputs']:
|
||||
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
|
||||
result = chunk['data']['outputs'][self.workflow_output_key]
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=result)
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.conversation_ids.pop(session_id, None)
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("Dify 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 Dify 的历史消息记录。")
|
||||
|
||||
async def terminate(self):
|
||||
await self.api_client.close()
|
||||
@@ -0,0 +1,287 @@
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
import aiohttp
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
|
||||
class SimpleGoogleGenAIClient():
|
||||
def __init__(self, api_key: str, api_base: str):
|
||||
self.api_key = api_key
|
||||
if api_base.endswith("/"):
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession()
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
async with self.client.get(request_url, timeout=10) as resp:
|
||||
response = await resp.json()
|
||||
|
||||
models = []
|
||||
for model in response["models"]:
|
||||
if 'generateContent' in model["supportedGenerationMethods"]:
|
||||
models.append(model["name"].replace("models/", ""))
|
||||
return models
|
||||
|
||||
async def generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str="gemini-1.5-flash",
|
||||
system_instruction: str="",
|
||||
tools: dict=None
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {
|
||||
"parts": {"text": system_instruction}
|
||||
}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
async with self.client.post(request_url, json=payload, timeout=10) as resp:
|
||||
response = await resp.json()
|
||||
return response
|
||||
|
||||
|
||||
@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器")
|
||||
class ProviderGoogleGenAI(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
|
||||
self.client = SimpleGoogleGenAIClient(
|
||||
api_key=self.chosen_api_key,
|
||||
api_base=provider_config.get("api_base", None)
|
||||
)
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
total_pages = len(contexts) // page_size
|
||||
if len(contexts) % page_size != 0:
|
||||
total_pages += 1
|
||||
|
||||
return paged_contexts, total_pages
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
|
||||
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
|
||||
'''
|
||||
弹出第一条记录
|
||||
'''
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
|
||||
if len(self.session_memory[session_id]) == 0:
|
||||
return None
|
||||
|
||||
for i in range(len(self.session_memory[session_id])):
|
||||
# 检查是否是 system prompt
|
||||
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
|
||||
# 如果只有一个 system prompt,才不删掉
|
||||
f = False
|
||||
for j in range(i+1, len(self.session_memory[session_id])):
|
||||
if self.session_memory[session_id][j]['user']['role'] == "system":
|
||||
f = True
|
||||
break
|
||||
if not f:
|
||||
continue
|
||||
record = self.session_memory[session_id].pop(i)
|
||||
break
|
||||
|
||||
return record
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
tool = None
|
||||
if tools:
|
||||
tool = tools.get_func_desc_google_genai_style()
|
||||
|
||||
system_instruction = ""
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "system":
|
||||
system_instruction = message["content"]
|
||||
break
|
||||
|
||||
google_genai_conversation = []
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
google_genai_conversation.append({
|
||||
"role": "user",
|
||||
"parts": [{"text": message["content"]}]
|
||||
})
|
||||
elif isinstance(message["content"], list):
|
||||
# images
|
||||
parts = []
|
||||
for part in message["content"]:
|
||||
if part["type"] == "text":
|
||||
parts.append({"text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
parts.append({"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": part["image_url"]["url"].replace("data:image/jpeg;base64,", "") # base64
|
||||
}})
|
||||
google_genai_conversation.append({
|
||||
"role": "user",
|
||||
"parts": parts
|
||||
})
|
||||
|
||||
elif message["role"] == "assistant":
|
||||
google_genai_conversation.append({
|
||||
"role": "model",
|
||||
"parts": [{"text": message["content"]}]
|
||||
})
|
||||
|
||||
|
||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
||||
|
||||
result = await self.client.generate_content(
|
||||
contents=google_genai_conversation,
|
||||
model=self.get_model(),
|
||||
system_instruction=system_instruction,
|
||||
tools=tool
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
candidates = result["candidates"][0]['content']['parts']
|
||||
llm_response = LLMResponse("assistant")
|
||||
for candidate in candidates:
|
||||
if 'text' in candidate:
|
||||
llm_response.completion_text += candidate['text']
|
||||
elif 'functionCall' in candidate:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args.append(candidate['functionCall']['args'])
|
||||
llm_response.tools_call_name.append(candidate['functionCall']['name'])
|
||||
|
||||
return llm_response
|
||||
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**self.provider_config.get("model_config", {})
|
||||
}
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
# 添加用户 record
|
||||
self.session_memory[session_id].append(new_record)
|
||||
# 添加 assistant record
|
||||
self.session_memory[session_id].append({
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
})
|
||||
else:
|
||||
self.session_memory[session_id] = [*contexts, new_record, {
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
return self.api_keys
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
'''
|
||||
组装上下文。
|
||||
'''
|
||||
if image_urls:
|
||||
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user","content": text}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
'''
|
||||
将图片转换为 base64
|
||||
'''
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ''
|
||||
@@ -162,7 +162,12 @@ class ProviderOpenAIOfficial(Provider):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
# 文本回复
|
||||
if not contexts:
|
||||
@@ -180,8 +185,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
}]
|
||||
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
|
||||
|
||||
return llm_response
|
||||
|
||||
async def forget(self, session_id: str) -> bool:
|
||||
self.session_memory[session_id] = []
|
||||
return True
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import traceback
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
|
||||
class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history = True
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str,
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
else:
|
||||
context_query = [*contexts, new_record]
|
||||
|
||||
model_cfgs: dict = self.provider_config.get("model_config", {})
|
||||
# glm-4v-flash 只支持一张图片
|
||||
model: str = model_cfgs.get("model", "")
|
||||
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
|
||||
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
|
||||
logger.debug(context_query)
|
||||
new_context_query_ = []
|
||||
for i in range(0, len(context_query) - 1, 2):
|
||||
if isinstance(context_query[i].get("content", ""), list):
|
||||
continue
|
||||
new_context_query_.append(context_query[i])
|
||||
new_context_query_.append(context_query[i+1])
|
||||
new_context_query_.append(context_query[-1]) # 保留最后一条记录
|
||||
context_query = new_context_query_
|
||||
logger.debug(context_query)
|
||||
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
payloads = {
|
||||
"messages": context_query,
|
||||
**model_cfgs
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import List
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
class SimpleOpenAIEmbedding():
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
api_key,
|
||||
api_base=None,
|
||||
) -> None:
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
)
|
||||
self.model = model
|
||||
|
||||
async def get_embedding(self, text) -> List[float]:
|
||||
'''
|
||||
获取文本的嵌入
|
||||
'''
|
||||
embedding = await self.client.embeddings.create(
|
||||
input=text,
|
||||
model=self.model
|
||||
)
|
||||
return embedding.data[0].embedding
|
||||
@@ -0,0 +1,92 @@
|
||||
import os
|
||||
from typing import List, Dict
|
||||
from astrbot.core import logger
|
||||
from .store import Store
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
class KnowledgeDBManager():
|
||||
def __init__(self, astrbot_config: AstrBotConfig) -> None:
|
||||
self.db_path = "data/knowledge_db/"
|
||||
self.config = astrbot_config.get("knowledge_db", {})
|
||||
self.astrbot_config = astrbot_config
|
||||
if not os.path.exists(self.db_path):
|
||||
os.makedirs(self.db_path)
|
||||
self.store_insts: Dict[str, Store] = {}
|
||||
for name, cfg in self.config.items():
|
||||
if cfg["strategy"] == "embedding":
|
||||
logger.info(f"加载 Chroma Vector Store:{name}")
|
||||
try:
|
||||
from .store.chroma_db import ChromaVectorStore
|
||||
except ImportError as ie:
|
||||
logger.error(f"{ie} 可能未安装 chromadb 库。")
|
||||
continue
|
||||
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
|
||||
else:
|
||||
logger.error(f"不支持的策略:{cfg['strategy']}")
|
||||
|
||||
|
||||
async def list_knowledge_db(self) -> List[str]:
|
||||
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
|
||||
|
||||
|
||||
async def create_knowledge_db(self, name: str, config: Dict):
|
||||
'''
|
||||
config 格式:
|
||||
```
|
||||
{
|
||||
"strategy": "embedding", # 目前只支持 embedding
|
||||
"chunk_method": {
|
||||
"strategy": "fixed",
|
||||
"chunk_size": 100,
|
||||
"overlap_size": 10
|
||||
},
|
||||
"embedding_config": {
|
||||
"strategy": "openai",
|
||||
"base_url": "",
|
||||
"model": "",
|
||||
"api_key": ""
|
||||
}
|
||||
}
|
||||
```
|
||||
'''
|
||||
if name in self.config:
|
||||
raise ValueError(f"知识库已存在:{name}")
|
||||
|
||||
self.config[name] = config
|
||||
self.astrbot_config["knowledge_db"] = self.config
|
||||
self.astrbot_config.save_config()
|
||||
|
||||
|
||||
async def insert_record(self, name: str, text: str):
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
ret = []
|
||||
match self.config[name]["chunk_method"]['strategy']:
|
||||
case "fixed":
|
||||
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
|
||||
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
|
||||
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
|
||||
case _:
|
||||
pass
|
||||
|
||||
for chunk in ret:
|
||||
await self.store_insts[name].save(chunk)
|
||||
|
||||
|
||||
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
|
||||
if name not in self.store_insts:
|
||||
raise ValueError(f"未找到知识库:{name}")
|
||||
|
||||
inst = self.store_insts[name]
|
||||
return await inst.query(query, top_n)
|
||||
|
||||
|
||||
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
|
||||
chunks = []
|
||||
start = 0
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
chunks.append(text[start:end])
|
||||
start += chunk_size - chunk_overlap
|
||||
return chunks
|
||||
@@ -0,0 +1,8 @@
|
||||
from typing import List
|
||||
|
||||
class Store():
|
||||
async def save(self, text: str):
|
||||
pass
|
||||
|
||||
async def query(self, query: str, top_n: int = 3) -> List[str]:
|
||||
pass
|
||||
@@ -0,0 +1,39 @@
|
||||
import chromadb
|
||||
import uuid
|
||||
from typing import List, Dict
|
||||
from astrbot.api import logger
|
||||
from ..embedding.openai_source import SimpleOpenAIEmbedding
|
||||
from . import Store
|
||||
|
||||
class ChromaVectorStore(Store):
|
||||
def __init__(self, name: str, embedding_cfg: Dict) -> None:
|
||||
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
|
||||
self.collection = self.chroma_client.get_or_create_collection(name=name)
|
||||
self.embedding = None
|
||||
if embedding_cfg["strategy"] == "openai":
|
||||
self.embedding = SimpleOpenAIEmbedding(
|
||||
model=embedding_cfg["model"],
|
||||
api_key=embedding_cfg["api_key"],
|
||||
api_base=embedding_cfg.get("base_url", None)
|
||||
)
|
||||
|
||||
async def save(self, text: str, metadata: Dict = None):
|
||||
logger.debug(f"Saving text: {text}")
|
||||
embedding = await self.embedding.get_embedding(text)
|
||||
|
||||
self.collection.upsert(
|
||||
documents=text,
|
||||
metadatas=metadata,
|
||||
ids=str(uuid.uuid4()),
|
||||
embeddings=embedding
|
||||
)
|
||||
|
||||
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
|
||||
embedding = await self.embedding.get_embedding(query)
|
||||
|
||||
results = self.collection.query(
|
||||
query_embeddings=embedding,
|
||||
n_results=top_n,
|
||||
where=metadata_filter
|
||||
)
|
||||
return results['documents'][0]
|
||||
@@ -1,6 +1,7 @@
|
||||
from asyncio import Queue
|
||||
from typing import List, TypedDict, Union
|
||||
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.provider.provider import Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -14,6 +15,7 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
|
||||
class StarCommand(TypedDict):
|
||||
full_command_name: str
|
||||
@@ -38,11 +40,22 @@ class Context:
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_star_manager = None
|
||||
|
||||
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
|
||||
def __init__(self,
|
||||
event_queue: Queue,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
knowledge_db_manager: KnowledgeDBManager = None
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
self._db = db
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.knowledge_db_manager = knowledge_db_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
for star in star_registry:
|
||||
@@ -73,7 +86,7 @@ class Context:
|
||||
event_type=EventType.OnLLMRequestEvent,
|
||||
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
||||
handler_name=func_obj.__name__,
|
||||
handler_module_str=func_obj.__module__,
|
||||
handler_module_path=func_obj.__module__,
|
||||
handler=func_obj,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
@@ -94,6 +107,12 @@ class Context:
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -105,6 +124,12 @@ class Context:
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -125,7 +150,7 @@ class Context:
|
||||
event_type=EventType.AdapterMessageEvent,
|
||||
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
||||
handler_name=awaitable.__name__,
|
||||
handler_module_str=awaitable.__module__,
|
||||
handler_module_path=awaitable.__module__,
|
||||
handler=awaitable,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
|
||||
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
|
||||
ls = re.split(r"\s+", message_str)
|
||||
if self.command_name != ls[0]:
|
||||
return False
|
||||
# if len(self.handler_params) == 0 and len(ls) > 1:
|
||||
# # 一定程度避免 LLM 聊天时误判为指令
|
||||
# return False
|
||||
# params_str = message_str[len(self.command_name):].strip()
|
||||
ls = ls[1:]
|
||||
# 去除空字符串
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
|
||||
event_type=event_type,
|
||||
handler_full_name=handler_full_name,
|
||||
handler_name=handler.__name__,
|
||||
handler_module_str=handler.__module__,
|
||||
handler_module_path=handler.__module__,
|
||||
handler=handler,
|
||||
event_filters=[]
|
||||
)
|
||||
@@ -185,7 +185,7 @@ def register_llm_tool(name: str = None):
|
||||
"description": arg.description
|
||||
})
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler)
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
|
||||
|
||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
||||
return awaitable
|
||||
|
||||
@@ -32,6 +32,9 @@ class StarMetadata:
|
||||
'''Star 的根目录名'''
|
||||
reserved: bool = False
|
||||
'''是否是 AstrBot 的保留 Star'''
|
||||
|
||||
activated: bool = True
|
||||
'''是否被激活'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
@@ -1,11 +1,12 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, List, Dict
|
||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
||||
from .filter import HandlerFilter
|
||||
from .star import star_map
|
||||
|
||||
|
||||
class StarHandlerRegistry(List):
|
||||
T = TypeVar('T', bound='StarHandlerMetadata')
|
||||
class StarHandlerRegistry(Generic[T], List[T]):
|
||||
'''用于存储所有的 Star Handler'''
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
@@ -16,9 +17,18 @@ class StarHandlerRegistry(List):
|
||||
super().append(handler)
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
|
||||
def get_handlers_by_event_type(self, event_type: EventType) -> List[StarHandlerMetadata]:
|
||||
def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]:
|
||||
'''通过事件类型获取 Handler'''
|
||||
return [handler for handler in self if handler.event_type == event_type]
|
||||
if only_activated:
|
||||
return [
|
||||
handler
|
||||
for handler in self
|
||||
if handler.event_type == event_type and
|
||||
star_map[handler.handler_module_path] and
|
||||
star_map[handler.handler_module_path].activated
|
||||
]
|
||||
else:
|
||||
return [handler for handler in self if handler.event_type == event_type]
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
'''通过 Handler 的全名获取 Handler'''
|
||||
@@ -26,8 +36,7 @@ class StarHandlerRegistry(List):
|
||||
|
||||
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
|
||||
'''通过模块名获取 Handler'''
|
||||
return [handler for handler in self if handler.handler_module_str == module_name]
|
||||
|
||||
return [handler for handler in self if handler.handler_module_path == module_name]
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
@@ -55,7 +64,7 @@ class StarHandlerMetadata():
|
||||
handler_name: str
|
||||
'''Handler 的名字,也就是方法名'''
|
||||
|
||||
handler_module_str: str
|
||||
handler_module_path: str
|
||||
'''Handler 所在的模块路径。'''
|
||||
|
||||
handler: Awaitable
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import inspect
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import yaml
|
||||
import logging
|
||||
@@ -8,15 +9,14 @@ from types import ModuleType
|
||||
from typing import List
|
||||
from pip import main as pip_main
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, sp
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
from .star import star_registry, star_map
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
from .star_handler import star_handlers_registry
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
|
||||
class PluginManager:
|
||||
def __init__(
|
||||
@@ -27,6 +27,7 @@ class PluginManager:
|
||||
self.updator = PluginUpdator(config['plugin_repo_mirror'])
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self # 就这样吧,不想改了
|
||||
|
||||
self.config = config
|
||||
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
|
||||
@@ -101,7 +102,7 @@ class PluginManager:
|
||||
'''更新插件的依赖'''
|
||||
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
|
||||
if self.config.pip_install_arg:
|
||||
args.extend(self.config.pip_install_arg)
|
||||
args.extend([self.config.pip_install_arg])
|
||||
result_code = pip_main(args)
|
||||
if result_code != 0:
|
||||
raise Exception(str(result_code))
|
||||
@@ -136,15 +137,29 @@ class PluginManager:
|
||||
|
||||
return metadata
|
||||
|
||||
def reload(self):
|
||||
async def reload(self):
|
||||
'''扫描并加载所有的 Star'''
|
||||
for smd in star_registry:
|
||||
logger.debug(f"尝试终止插件 {smd.name} ...")
|
||||
if hasattr(smd.star_cls, "__del__"):
|
||||
smd.star_cls.__del__()
|
||||
|
||||
star_handlers_registry.clear()
|
||||
star_handlers_registry.star_handlers_map.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
|
||||
plugin_modules = self._get_plugin_modules()
|
||||
if plugin_modules is None:
|
||||
return False, "未找到任何插件模块"
|
||||
fail_rec = ""
|
||||
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
# 导入 Star 模块,并尝试实例化 Star 类
|
||||
for plugin_module in plugin_modules:
|
||||
try:
|
||||
@@ -171,21 +186,24 @@ class PluginManager:
|
||||
|
||||
if path in star_map:
|
||||
# 通过装饰器的方式注册插件
|
||||
star_metadata = star_map[path]
|
||||
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
|
||||
star_metadata.module = module
|
||||
star_metadata.root_dir_name = root_dir_name
|
||||
star_metadata.reserved = reserved
|
||||
metadata = star_map[path]
|
||||
metadata.star_cls = metadata.star_cls_type(context=self.context)
|
||||
metadata.module = module
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
|
||||
related_handlers = star_handlers_registry.get_handlers_by_module_name(star_metadata.module_path)
|
||||
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
|
||||
for handler in related_handlers:
|
||||
logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}")
|
||||
logger.debug(f"bind handler {handler.handler_name} to {metadata.name}")
|
||||
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
|
||||
handler.handler = functools.partial(handler.handler, star_metadata.star_cls)
|
||||
handler.handler = functools.partial(handler.handler, metadata.star_cls)
|
||||
# llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler.__module__ == star_metadata.module_path:
|
||||
func_tool.handler = functools.partial(func_tool.handler, star_metadata.star_cls)
|
||||
if func_tool.handler.__module__ == metadata.module_path:
|
||||
func_tool.handler_module_path = metadata.module_path
|
||||
func_tool.handler = functools.partial(func_tool.handler, metadata.star_cls)
|
||||
if func_tool.name in inactivated_llm_tools:
|
||||
func_tool.active = False
|
||||
|
||||
else:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
@@ -209,6 +227,13 @@ class PluginManager:
|
||||
star_map[path] = metadata
|
||||
star_registry.append(metadata)
|
||||
logger.debug(f"插件 {root_dir_name} 载入成功。")
|
||||
|
||||
if metadata.module_path in inactivated_plugins:
|
||||
metadata.activated = False
|
||||
|
||||
# 执行 initialize 函数
|
||||
if hasattr(metadata.star_cls, "initialize"):
|
||||
await metadata.star_cls.initialize()
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
@@ -225,10 +250,11 @@ class PluginManager:
|
||||
|
||||
async def install_plugin(self, repo_url: str):
|
||||
plugin_path = await self.updator.install(repo_url)
|
||||
self._check_plugin_dept_update()
|
||||
# reload the plugin
|
||||
await self.reload()
|
||||
return plugin_path
|
||||
|
||||
def uninstall_plugin(self, plugin_name: str):
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
@@ -237,10 +263,26 @@ class PluginManager:
|
||||
root_dir_name = plugin.root_dir_name
|
||||
ppath = self.plugin_store_path
|
||||
|
||||
del star_map[plugin.module_path]
|
||||
# 从 star_registry 和 star_map 中删除
|
||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||
|
||||
if not remove_dir(os.path.join(ppath, root_dir_name)):
|
||||
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
|
||||
|
||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||
del star_map[plugin_module_path]
|
||||
for i, p in enumerate(star_registry):
|
||||
if p.name == plugin_name:
|
||||
del star_registry[i]
|
||||
break
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path):
|
||||
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
|
||||
star_handlers_registry.remove(handler)
|
||||
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)]
|
||||
for k in keys_to_delete:
|
||||
v = star_handlers_registry.star_handlers_map[k]
|
||||
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
async def update_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
@@ -250,6 +292,46 @@ class PluginManager:
|
||||
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
|
||||
|
||||
await self.updator.update(plugin)
|
||||
await self.reload()
|
||||
|
||||
async def turn_off_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
if plugin.module_path not in inactivated_plugins:
|
||||
inactivated_plugins.append(plugin.module_path)
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
|
||||
# 禁用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
func_tool.active = False
|
||||
inactivated_llm_tools.append(func_tool.name)
|
||||
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = False
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if plugin.module_path in inactivated_plugins:
|
||||
inactivated_plugins.remove(plugin.module_path)
|
||||
sp.put("inactivated_plugins", inactivated_plugins)
|
||||
|
||||
# 启用插件启用的 llm_tool
|
||||
for func_tool in llm_tools.func_list:
|
||||
if func_tool.handler_module_path == plugin.module_path:
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
func_tool.active = True
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
plugin.activated = True
|
||||
|
||||
|
||||
def install_plugin_from_file(self, zip_file_path: str):
|
||||
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
|
||||
@@ -262,3 +344,4 @@ class PluginManager:
|
||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||
|
||||
self._check_plugin_dept_update()
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
|
||||
|
||||
files = os.listdir(os.path.join(target_dir, update_dir))
|
||||
for f in files:
|
||||
logger.info(f"移动更新文件/目录: {f}")
|
||||
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
|
||||
if os.path.exists(os.path.join(target_dir, f)):
|
||||
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
|
||||
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
|
||||
|
||||
try:
|
||||
logger.info(f"删除临时更新文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
logger.info(f"删除临时文件: {zip_path} 和 {os.path.join(target_dir, update_dir)}")
|
||||
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
|
||||
os.remove(zip_path)
|
||||
except BaseException:
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
from aiohttp import ClientSession
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
|
||||
class DifyAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = ClientSession()
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
async def chat_messages(
|
||||
self,
|
||||
inputs: Dict,
|
||||
query: str,
|
||||
user: str,
|
||||
response_mode: str = "streaming",
|
||||
conversation_id: str = "",
|
||||
files: List[Dict[str, Any]] = [],
|
||||
timeout: float = 60,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
url = f"{self.api_base}/chat-messages"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
inputs: Dict,
|
||||
user: str,
|
||||
response_mode: str = "streaming",
|
||||
files: List[Dict[str, Any]] = [],
|
||||
timeout: float = 60,
|
||||
):
|
||||
url = f"{self.api_base}/workflows/run"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
file_path: str,
|
||||
user: str,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.api_base}/files/upload"
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": open(file_path, "rb"),
|
||||
}
|
||||
async with self.session.post(
|
||||
url, data=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
|
||||
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
@@ -5,6 +5,7 @@ import socket
|
||||
import time
|
||||
import aiohttp
|
||||
import base64
|
||||
import zipfile
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -96,7 +97,9 @@ async def download_file(url: str, path: str):
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
async with session.get(url, timeout=20) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
@@ -123,3 +126,10 @@ def get_local_ip_addresses():
|
||||
finally:
|
||||
s.close()
|
||||
return ip
|
||||
|
||||
async def download_dashboard():
|
||||
'''下载管理面板文件'''
|
||||
dashboard_release_url = "https://astrbot-registry.lwl.lol/download/astrbot-dashboard/latest/dist.zip"
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
@@ -22,6 +22,9 @@ class ParameterValidationMixin:
|
||||
result[param_name] = int(params[i])
|
||||
else:
|
||||
result[param_name] = params[i]
|
||||
elif isinstance(param_type_or_default_val, str):
|
||||
# 如果 param_type_or_default_val 是字符串,直接赋值
|
||||
result[param_name] = params[i]
|
||||
else:
|
||||
result[param_name] = param_type_or_default_val(params[i])
|
||||
except ValueError:
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
class SharedPreferences:
|
||||
def __init__(self, path="data/shared_preferences.json"):
|
||||
self.path = path
|
||||
self._data = self._load_preferences()
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._data.get(key, default)
|
||||
|
||||
def put(self, key, value):
|
||||
self._data[key] = value
|
||||
self._save_preferences()
|
||||
|
||||
def remove(self, key):
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
self._save_preferences()
|
||||
|
||||
def clear(self):
|
||||
self._data.clear()
|
||||
self._save_preferences()
|
||||
@@ -111,7 +111,7 @@ class RepoZipUpdator():
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,将从默认分支下载。")
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
@@ -7,6 +7,7 @@ from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.star.config import update_config
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.platform.register import platform_registry
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
@@ -121,6 +122,12 @@ class ConfigRoute(Route):
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
|
||||
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
|
||||
for platform in platform_registry:
|
||||
if platform.default_config_tmpl:
|
||||
platform_default_tmpl[platform.name] = platform.default_config_tmpl
|
||||
|
||||
return {
|
||||
"metadata": CONFIG_METADATA_2,
|
||||
"config": config
|
||||
|
||||
@@ -16,7 +16,9 @@ class PluginRoute(Route):
|
||||
'/plugin/install-upload': ('POST', self.install_plugin_upload),
|
||||
'/plugin/update': ('POST', self.update_plugin),
|
||||
'/plugin/uninstall': ('POST', self.uninstall_plugin),
|
||||
'/plugin/market_list': ('GET', self.get_online_plugins)
|
||||
'/plugin/market_list': ('GET', self.get_online_plugins),
|
||||
'/plugin/off': ('POST', self.off_plugin),
|
||||
'/plugin/on': ('POST', self.on_plugin)
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.plugin_manager = plugin_manager
|
||||
@@ -42,7 +44,8 @@ class PluginRoute(Route):
|
||||
"author": plugin.author,
|
||||
"desc": plugin.desc,
|
||||
"version": plugin.version,
|
||||
"reserved": plugin.reserved
|
||||
"reserved": plugin.reserved,
|
||||
"activated": plugin.activated
|
||||
}
|
||||
_plugin_resp.append(_t)
|
||||
return Response().ok(_plugin_resp).__dict__
|
||||
@@ -53,7 +56,6 @@ class PluginRoute(Route):
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
await self.plugin_manager.install_plugin(repo_url)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {repo_url} 成功。")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
@@ -69,7 +71,6 @@ class PluginRoute(Route):
|
||||
await file.save(file_path)
|
||||
self.plugin_manager.install_plugin_from_file(file_path)
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
self.core_lifecycle.restart()
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
@@ -80,7 +81,7 @@ class PluginRoute(Route):
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
logger.info(f"正在卸载插件 {plugin_name}")
|
||||
self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
await self.plugin_manager.uninstall_plugin(plugin_name)
|
||||
logger.info(f"卸载插件 {plugin_name} 成功")
|
||||
return Response().ok(None, "卸载成功").__dict__
|
||||
except Exception as e:
|
||||
@@ -93,9 +94,30 @@ class PluginRoute(Route):
|
||||
try:
|
||||
logger.info(f"正在更新插件 {plugin_name}")
|
||||
await self.plugin_manager.update_plugin(plugin_name)
|
||||
self.core_lifecycle.restart()
|
||||
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
|
||||
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
|
||||
logger.info(f"更新插件 {plugin_name} 成功。")
|
||||
return Response().ok(None, "更新成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def off_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_off_plugin(plugin_name)
|
||||
logger.info(f"停用插件 {plugin_name} 。")
|
||||
return Response().ok(None, "停用成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/off: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def on_plugin(self):
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
await self.plugin_manager.turn_on_plugin(plugin_name)
|
||||
logger.info(f"启用插件 {plugin_name} 。")
|
||||
return Response().ok(None, "启用成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/extensions/on: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -32,6 +32,7 @@ class UpdateRoute(Route):
|
||||
async def update_project(self):
|
||||
data = await request.json
|
||||
version = data.get('version', '')
|
||||
reboot = data.get('reboot', True)
|
||||
if version == "" or version == "latest":
|
||||
latest = True
|
||||
version = ''
|
||||
@@ -39,8 +40,11 @@ class UpdateRoute(Route):
|
||||
latest = False
|
||||
try:
|
||||
await self.astrbot_updator.update(latest=latest, version=version)
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
if reboot:
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
else:
|
||||
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -18,7 +18,6 @@ class AstrBotDashboard():
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.config = core_lifecycle.astrbot_config
|
||||
self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist"))
|
||||
logger.info(f"Dashboard data path: {self.data_path}")
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
self.app.json.sort_keys = False
|
||||
self.app.before_request(self.auth_middleware)
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# What's Changed
|
||||
|
||||
1. 修复了 reminder 插件可能不会触发回调的问题。
|
||||
2. 修复了 telegram 插件不可用的问题。
|
||||
3. 修复了 qq_official 无法发图的问题。
|
||||
4. 修复事件监听器会让 WakeStage 失效的问题。
|
||||
5. 修复 websearch 在 cmd_config 中失效的问题。
|
||||
3. 支持通过 Google GenAI 访问 Gemini 模型,而不需要使用 Gemini 对 OpenAI 的兼容 API。详见文档。
|
||||
4. 支持对插件禁用/启用。/plugin off/on <plugin_name>
|
||||
5. 支持基于 Docker 的沙箱化代码执行器。(Beta 测试)详见文档。
|
||||
6. 支持接入 Dify LLMOps 平台。详见文档。
|
||||
7. 适配器类插件支持设置默认配置模板。
|
||||
8. 优化了部分指令的持久化记忆。如 /tool 的禁用、/provider 的选择都将持久化保存,每次启动时不需要重新设置。
|
||||
9. 优化了 glm-4v-flash 模型。其只支持一张图。
|
||||
@@ -1,4 +1,2 @@
|
||||
node_modules/
|
||||
.DS_Store
|
||||
package-lock.json
|
||||
package.json
|
||||
.DS_Store
|
||||
Generated
+9981
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,59 @@
|
||||
{
|
||||
"name": "astrbot-dashboard",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"author": "CodedThemes",
|
||||
"scripts": {
|
||||
"dev": "vite --host",
|
||||
"build": "vue-tsc --noEmit && vite build",
|
||||
"build-stage": "vue-tsc --noEmit && vite build --base=/vue/free/stage/",
|
||||
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
|
||||
"preview": "vite preview --port 5050",
|
||||
"typecheck": "vue-tsc --noEmit",
|
||||
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
|
||||
},
|
||||
"dependencies": {
|
||||
"@guolao/vue-monaco-editor": "^1.5.4",
|
||||
"@tiptap/starter-kit": "2.1.7",
|
||||
"@tiptap/vue-3": "2.1.7",
|
||||
"apexcharts": "3.42.0",
|
||||
"axios": "^1.6.2",
|
||||
"axios-mock-adapter": "^1.22.0",
|
||||
"chance": "1.1.11",
|
||||
"date-fns": "2.30.0",
|
||||
"js-md5": "^0.8.3",
|
||||
"lodash": "4.17.21",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
"vite-plugin-vuetify": "1.0.2",
|
||||
"vue": "3.3.4",
|
||||
"vue-router": "4.2.4",
|
||||
"vue3-apexcharts": "1.4.4",
|
||||
"vue3-print-nb": "0.1.4",
|
||||
"vuetify": "3.3.14",
|
||||
"xterm": "^5.3.0",
|
||||
"xterm-addon-fit": "^0.8.0",
|
||||
"yup": "1.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@mdi/font": "7.2.96",
|
||||
"@rushstack/eslint-patch": "1.3.3",
|
||||
"@types/chance": "1.1.3",
|
||||
"@types/node": "20.5.7",
|
||||
"@vitejs/plugin-vue": "4.3.3",
|
||||
"@vue/eslint-config-prettier": "8.0.0",
|
||||
"@vue/eslint-config-typescript": "11.0.3",
|
||||
"@vue/tsconfig": "0.4.0",
|
||||
"eslint": "8.48.0",
|
||||
"eslint-plugin-vue": "9.17.0",
|
||||
"prettier": "3.0.2",
|
||||
"sass": "1.66.1",
|
||||
"sass-loader": "13.3.2",
|
||||
"typescript": "5.1.6",
|
||||
"vite": "4.4.9",
|
||||
"vue-cli-plugin-vuetify": "2.5.8",
|
||||
"vue-tsc": "1.8.8",
|
||||
"vuetify-loader": "^2.0.0-alpha.9"
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@
|
||||
</v-alert>
|
||||
|
||||
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
|
||||
<div style="width: 100%;">
|
||||
<div style="width: 100%;" v-if="metadata[metadataKey].items[key]">
|
||||
<v-select v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible" v-model="iterable[key]"
|
||||
variant="outlined" :items="metadata[metadataKey].items[key]?.options"
|
||||
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" dense :disabled="metadata[metadataKey].items[key]?.readonly"></v-select>
|
||||
@@ -46,6 +46,11 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style="width: 100%;" v-else>
|
||||
<!-- 在 metadata 中没有 key -->
|
||||
<v-text-field v-model="iterable[key]" :label="key" variant="outlined" dense></v-text-field>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object' && !metadata[metadataKey].items[key]?.invisible">
|
||||
<v-btn icon size="x-small" style="margin-bottom: 22px;">
|
||||
|
||||
@@ -29,7 +29,9 @@ import axios from 'axios';
|
||||
<v-btn variant="plain" @click="updateExtension(extension.name)">更新</v-btn>
|
||||
<v-btn variant="plain" @click="uninstallExtension(extension.name)">卸载</v-btn>
|
||||
</div>
|
||||
<span v-else>保留插件</span>
|
||||
<!-- <span v-else>保留插件</span> -->
|
||||
<v-btn variant="plain" v-if="extension.activated" @click="pluginOff(extension)">禁用</v-btn>
|
||||
<v-btn variant="plain" v-else @click="pluginOn(extension)">启用</v-btn>
|
||||
</div>
|
||||
</ExtensionCard>
|
||||
</v-col>
|
||||
@@ -329,6 +331,36 @@ export default {
|
||||
this.toast(err, "error");
|
||||
});
|
||||
},
|
||||
pluginOn(extension) {
|
||||
axios.post('/api/plugin/on',
|
||||
{
|
||||
name: extension.name
|
||||
}).then((res) => {
|
||||
if (res.data.status === "error") {
|
||||
this.toast(res.data.message, "error");
|
||||
return;
|
||||
}
|
||||
this.toast(res.data.message, "success");
|
||||
this.getExtensions();
|
||||
}).catch((err) => {
|
||||
this.toast(err, "error");
|
||||
});
|
||||
},
|
||||
pluginOff(extension) {
|
||||
axios.post('/api/plugin/off',
|
||||
{
|
||||
name: extension.name
|
||||
}).then((res) => {
|
||||
if (res.data.status === "error") {
|
||||
this.toast(res.data.message, "error");
|
||||
return;
|
||||
}
|
||||
this.toast(res.data.message, "success");
|
||||
this.getExtensions();
|
||||
}).catch((err) => {
|
||||
this.toast(err, "error");
|
||||
});
|
||||
},
|
||||
openExtensionConfig(extension_name) {
|
||||
this.curr_namespace = extension_name;
|
||||
this.configDialog = true;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import sys
|
||||
@@ -8,6 +7,8 @@ import zipfile
|
||||
from astrbot.dashboard import AstrBotDashBoardLifecycle
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core import logger, LogManager, LogBroker
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
|
||||
# add parent path to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -39,25 +40,20 @@ def check_env():
|
||||
async def check_dashboard_files():
|
||||
'''下载管理面板文件'''
|
||||
if os.path.exists("data/dist"):
|
||||
return
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
logger.info("开始下载管理面板文件...")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(dashboard_release_url) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"下载管理面板文件失败: {resp.status}")
|
||||
with open("data/dashboard.zip", "wb") as f:
|
||||
f.write(await resp.read())
|
||||
logger.info("管理面板文件下载完成。")
|
||||
ok = True
|
||||
|
||||
if not ok:
|
||||
logger.critical("下载管理面板文件失败")
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
if f.read() != VERSION:
|
||||
logger.warning("检测到管理面板有更新。可以使用 /dashboard update 命令更新。")
|
||||
return
|
||||
|
||||
# unzip
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
logger.info("开始下载管理面板文件...")
|
||||
|
||||
try:
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.critical(f"下载管理面板文件失败: {e}")
|
||||
return
|
||||
|
||||
logger.info("管理面板下载完成。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
+77
-16
@@ -3,8 +3,9 @@ import datetime
|
||||
import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import personalities
|
||||
from astrbot.api import personalities, sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
|
||||
from typing import Union
|
||||
|
||||
@@ -16,6 +17,8 @@ class Main(star.Star):
|
||||
self.prompt_prefix = cfg['provider_settings']['prompt_prefix']
|
||||
self.identifier = cfg['provider_settings']['identifier']
|
||||
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
|
||||
|
||||
self.kdb_enabled = False
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
@@ -42,6 +45,7 @@ class Main(star.Star):
|
||||
/deop <admin_id>: 取消管理员
|
||||
/wl <sid>: 添加会话白名单
|
||||
/dwl <sid>: 删除会话白名单
|
||||
/dashboard update: 更新管理面板
|
||||
|
||||
[大模型]
|
||||
/provider: 查看、切换大模型提供商
|
||||
@@ -87,24 +91,41 @@ class Main(star.Star):
|
||||
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。"))
|
||||
|
||||
@filter.command("plugin")
|
||||
async def plugin(self, event: AstrMessageEvent, oper: str = None):
|
||||
if oper is None:
|
||||
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
|
||||
if oper1 is None:
|
||||
plugin_list_info = "已加载的插件:\n"
|
||||
for plugin in self.context.get_all_stars():
|
||||
plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}\n"
|
||||
if plugin_list_info.strip() == "":
|
||||
plugin_list_info = "没有加载任何插件。"
|
||||
|
||||
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。"
|
||||
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
|
||||
event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False))
|
||||
else:
|
||||
plugin = self.context.get_registered_star(oper)
|
||||
if plugin is None:
|
||||
event.set_result(MessageEventResult().message("未找到此插件。"))
|
||||
if oper1 == "off":
|
||||
# 禁用插件
|
||||
if oper2 is None:
|
||||
event.set_result(MessageEventResult().message("/plugin off <插件名> 禁用插件。"))
|
||||
return
|
||||
await self.context._star_manager.turn_off_plugin(oper2)
|
||||
event.set_result(MessageEventResult().message(f"插件 {oper2} 已禁用。"))
|
||||
elif oper1 == "on":
|
||||
# 启用插件
|
||||
if oper2 is None:
|
||||
event.set_result(MessageEventResult().message("/plugin on <插件名> 启用插件。"))
|
||||
return
|
||||
await self.context._star_manager.turn_on_plugin(oper2)
|
||||
event.set_result(MessageEventResult().message(f"插件 {oper2} 已启用。"))
|
||||
|
||||
else:
|
||||
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
|
||||
ret = f"插件 {oper} 帮助信息:\n" + help_msg
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
# 获取插件帮助
|
||||
plugin = self.context.get_registered_star(oper1)
|
||||
if plugin is None:
|
||||
event.set_result(MessageEventResult().message("未找到此插件。"))
|
||||
else:
|
||||
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
|
||||
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@filter.command("t2i")
|
||||
async def t2i(self, event: AstrMessageEvent):
|
||||
@@ -167,8 +188,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
if idx is None:
|
||||
ret = "## 当前载入的 LLM 提供商\n"
|
||||
for idx, llm in enumerate(self.context.get_all_providers()):
|
||||
ret += f"{idx + 1}. {llm.meta().id} ({llm.meta().model})"
|
||||
if self.provider == llm:
|
||||
id_ = llm.meta().id
|
||||
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
|
||||
if self.context.get_using_provider().meta().id == id_:
|
||||
ret += " (当前使用)"
|
||||
ret += "\n"
|
||||
|
||||
@@ -178,9 +200,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
self.context.provider_manager.curr_provider_inst = self.context.get_all_providers()[idx - 1]
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
self.context.provider_manager.curr_provider_inst = provider
|
||||
sp.put("curr_provider", id_)
|
||||
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {self.context.provider_manager.curr_provider_inst.meta().id}。"))
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
|
||||
@filter.command("reset")
|
||||
async def reset(self, message: AstrMessageEvent):
|
||||
@@ -289,7 +314,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
- 重置 LLM 会话(保留人格): /reset p
|
||||
|
||||
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
|
||||
"""))
|
||||
""").use_t2i(False))
|
||||
elif l[1] == "list":
|
||||
msg = "人格列表:\n"
|
||||
for key in personalities.keys():
|
||||
@@ -318,6 +343,13 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
name="自定义人格", prompt=ps)
|
||||
message.set_result(
|
||||
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard update")
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
yield event.plain_result("正在尝试更新管理面板...")
|
||||
await download_dashboard()
|
||||
yield event.plain_result("管理面板更新完成。")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
@@ -337,4 +369,33 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
|
||||
async def other_message(self, event: AstrMessageEvent):
|
||||
print("triggered")
|
||||
event.stop_event()
|
||||
event.stop_event()
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
pass
|
||||
|
||||
@kdb.command("on")
|
||||
async def on_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = True
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if not curr_kdb_name:
|
||||
yield event.plain_result("未载入任何知识库")
|
||||
else:
|
||||
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
|
||||
|
||||
@kdb.command("off")
|
||||
async def off_kdb(self, event: AstrMessageEvent):
|
||||
self.kdb_enabled = False
|
||||
yield event.plain_result("知识库已关闭")
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
|
||||
curr_kdb_name = self.context.provider_manager.curr_kdb_name
|
||||
if self.kdb_enabled and curr_kdb_name:
|
||||
mgr = self.context.knowledge_db_manager
|
||||
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
|
||||
if results:
|
||||
req.system_prompt += "\nHere are documents that related to user's query: \n"
|
||||
for result in results:
|
||||
req.system_prompt += f"- {result}\n"
|
||||
@@ -0,0 +1,382 @@
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import aiohttp
|
||||
import uuid
|
||||
import asyncio
|
||||
import re
|
||||
import astrbot.api.star as star
|
||||
import aiodocker
|
||||
from collections import defaultdict
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import llm_tool, logger
|
||||
from astrbot.api.event import filter
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api.message_components import Image, File
|
||||
|
||||
PROMPT = """
|
||||
## Task
|
||||
You need to generate python codes to solve user's problem: {prompt}
|
||||
|
||||
{extra_input}
|
||||
|
||||
## Limit
|
||||
1. Available libraries:
|
||||
- standard libs
|
||||
- `Pillow`
|
||||
- `requests`
|
||||
- `numpy`
|
||||
- `matplotlib`
|
||||
- `scipy`
|
||||
- `scikit-learn`
|
||||
- `beautifulsoup4`
|
||||
- `pandas`
|
||||
- `opencv-python`
|
||||
- `python-docx`
|
||||
- `python-pptx`
|
||||
- `pymupdf` (Do not use fpdf, reportlab, etc.)
|
||||
- `mplfonts`
|
||||
You can only use these libraries and the libraries that they depend on.
|
||||
2. Do not generate malicious code.
|
||||
3. Use given `shared.api` package to output the result.
|
||||
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
|
||||
For Image and file, you must save it to `output` folder.
|
||||
4. You must only output the code, do not output the result of the code and any other information.
|
||||
5. The output language is same as user's input language.
|
||||
6. Please first provide relevant knowledge about user's problem appropriately.
|
||||
|
||||
## Example
|
||||
1. User's problem: `please solve the fabonacci sequence problem.`
|
||||
Output:
|
||||
```python
|
||||
from shared.api import send_text, send_image, send_file
|
||||
|
||||
def fabonacci(n):
|
||||
if n <= 1:
|
||||
return n
|
||||
else:
|
||||
return fabonacci(n-1) + fabonacci(n-2)
|
||||
|
||||
result = fabonacci(10)
|
||||
send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.")
|
||||
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
|
||||
```
|
||||
|
||||
2. User's problem: `please draw a sin(x) function.`
|
||||
Output:
|
||||
```python
|
||||
from shared.api import send_text, send_image, send_file
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
x = np.linspace(0, 2*np.pi, 100)
|
||||
y = np.sin(x)
|
||||
plt.plot(x, y)
|
||||
plt.savefig("output/sin_x.png")
|
||||
send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).")
|
||||
send_image("output/sin_x.png") # send_image is a function to send image to user
|
||||
send_text("If you need more information, please let me know :)")
|
||||
```
|
||||
|
||||
{extra_prompt}
|
||||
"""
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"sandbox": {
|
||||
"image": "soulter/astrbot-code-interpreter-sandbox",
|
||||
"docker_mirror": "", # cjie.eu.org
|
||||
}
|
||||
}
|
||||
PATH = "data/config/python_interpreter.json"
|
||||
|
||||
@star.register(name="astrbot-python-interpreter", desc="Python 代码执行器", author="Soulter", version="0.0.1")
|
||||
class Main(star.Star):
|
||||
'''基于 Docker 沙箱的 Python 代码执行器'''
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
self.workplace_path = os.path.join(self.curr_dir, "workplace")
|
||||
self.shared_path = os.path.join(self.curr_dir, "shared")
|
||||
os.makedirs(self.workplace_path, exist_ok=True)
|
||||
|
||||
self.user_file_msg_buffer = defaultdict(list)
|
||||
'''存放用户上传的文件'''
|
||||
|
||||
# 加载配置
|
||||
if not os.path.exists(PATH):
|
||||
self.config = DEFAULT_CONFIG
|
||||
self._save_config()
|
||||
else:
|
||||
with open(PATH, "r") as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
async def initialize(self):
|
||||
ok = await self.is_docker_available()
|
||||
if not ok:
|
||||
logger.warning("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。")
|
||||
await self.context._star_manager.turn_off_plugin("astrbot-python-interpreter")
|
||||
|
||||
async def file_upload(self, file_path: str):
|
||||
'''
|
||||
上传图像文件到 S3
|
||||
'''
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
|
||||
with open(file_path, "rb") as f:
|
||||
file = f.read()
|
||||
|
||||
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
|
||||
|
||||
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
|
||||
async with session.put(s3_file_url, data=file) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to upload image: {resp.status}")
|
||||
return s3_file_url
|
||||
|
||||
|
||||
async def is_docker_available(self) -> bool:
|
||||
'''Check if docker is available'''
|
||||
try:
|
||||
docker = aiodocker.Docker()
|
||||
await docker.version()
|
||||
return True
|
||||
except aiodocker.exceptions.DockerError as e:
|
||||
logger.error(f"检查 Docker 可用性时出现问题: {e}")
|
||||
return False
|
||||
|
||||
async def get_image_name(self) -> str:
|
||||
'''Get the image name'''
|
||||
if self.config["sandbox"]["docker_mirror"]:
|
||||
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
|
||||
return self.config["sandbox"]["image"]
|
||||
|
||||
async def _save_config(self):
|
||||
with open(PATH, "w") as f:
|
||||
json.dump(self.config, f)
|
||||
|
||||
async def gen_magic_code(self) -> str:
|
||||
return uuid.uuid4().hex[:8]
|
||||
|
||||
async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str:
|
||||
'''Download image from url to workplace_path'''
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as resp:
|
||||
if resp.status != 200:
|
||||
return ""
|
||||
image_path = os.path.join(workplace_path, f"{filename}.jpg")
|
||||
with open(image_path, 'wb') as f:
|
||||
f.write(await resp.read())
|
||||
return f"{filename}.jpg"
|
||||
|
||||
async def tidy_code(self, code: str) -> str:
|
||||
'''Tidy the code'''
|
||||
pattern = r"```(?:py|python)?\n(.*?)\n```"
|
||||
match = re.search(pattern, code, re.DOTALL)
|
||||
if match is None:
|
||||
raise ValueError("The code is not in the code block.")
|
||||
return match.group(1)
|
||||
|
||||
@filter.event_message_type(filter.EventMessageType.ALL)
|
||||
async def on_message(self, event: AstrMessageEvent):
|
||||
'''处理消息'''
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, File):
|
||||
self.user_file_msg_buffer[event.get_session_id()].append(comp.file)
|
||||
logger.debug(f"User uploaded file: {comp.file}")
|
||||
break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。
|
||||
|
||||
@filter.on_llm_request()
|
||||
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
|
||||
if event.get_session_id() in self.user_file_msg_buffer:
|
||||
files = self.user_file_msg_buffer[event.get_session_id()]
|
||||
request.prompt += f"\nUser provided files: {files}"
|
||||
|
||||
|
||||
@filter.command_group("pi")
|
||||
def pi(self):
|
||||
pass
|
||||
|
||||
|
||||
@pi.command("mirror")
|
||||
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
|
||||
'''Docker 镜像地址'''
|
||||
if not url:
|
||||
yield event.plain_result(f"""当前 Docker 镜像地址: {self.config['sandbox']['docker_mirror']}。
|
||||
使用 `pi mirror <url>` 来设置 Docker 镜像地址。
|
||||
您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。
|
||||
""")
|
||||
else:
|
||||
self.config["sandbox"]["docker_mirror"] = url
|
||||
await self._save_config()
|
||||
yield event.plain_result("设置 Docker 镜像地址成功。")
|
||||
|
||||
@pi.command("repull")
|
||||
async def pi_repull(self, event: AstrMessageEvent):
|
||||
'''重新拉取沙箱镜像'''
|
||||
docker = aiodocker.Docker()
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
await docker.images.delete(image_name, force=True)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
yield event.plain_result("重新拉取沙箱镜像成功。")
|
||||
|
||||
|
||||
@llm_tool("python_interpreter")
|
||||
async def python_interpreter(self, event: AstrMessageEvent):
|
||||
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
|
||||
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
|
||||
'''
|
||||
if not await self.is_docker_available():
|
||||
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
|
||||
|
||||
plain_text = event.message_str
|
||||
|
||||
# 创建必要的工作目录和幻术码
|
||||
magic_code = await self.gen_magic_code()
|
||||
workplace_path = os.path.join(self.workplace_path, magic_code)
|
||||
output_path = os.path.join(workplace_path, "output")
|
||||
os.makedirs(workplace_path, exist_ok=True)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
# 图片
|
||||
images = []
|
||||
idx = 1
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
if image_url.startswith("http"):
|
||||
image_path = await self.download_image(image_url, workplace_path, f"img_{idx}")
|
||||
if image_path:
|
||||
images.append(image_path)
|
||||
idx += 1
|
||||
# 文件
|
||||
files = []
|
||||
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
|
||||
# cp
|
||||
file_name = os.path.basename(file_path)
|
||||
shutil.copy(file_path, os.path.join(workplace_path, file_name))
|
||||
files.append(file_name)
|
||||
|
||||
logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")
|
||||
|
||||
# 整理额外输入
|
||||
extra_inputs = ""
|
||||
if images:
|
||||
extra_inputs += f"User provided images: {images}\n"
|
||||
if files:
|
||||
extra_inputs += f"User provided files: {files}\n"
|
||||
|
||||
obs = ""
|
||||
n = 5
|
||||
|
||||
for i in range(n):
|
||||
if i > 0:
|
||||
logger.info(f"Try {i+1}/{n}")
|
||||
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
provider = self.context.get_using_provider()
|
||||
llm_response = await provider.text_chat(prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}")
|
||||
|
||||
logger.debug("code interpreter llm gened code:" + llm_response.completion_text)
|
||||
|
||||
# 整理代码并保存
|
||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
||||
f.write(code_clean)
|
||||
|
||||
# 启动容器
|
||||
docker = aiodocker.Docker()
|
||||
|
||||
# 检查有没有image
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
# 拉取镜像
|
||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
||||
await docker.images.pull(image_name)
|
||||
|
||||
yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})")
|
||||
|
||||
container = await docker.containers.run({
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
"Memory": 512 * 1024 * 1024,
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{self.shared_path}:/astrbot_sandbox/shared:ro",
|
||||
f"{output_path}:/astrbot_sandbox/output:rw",
|
||||
f"{workplace_path}:/astrbot_sandbox:rw",
|
||||
]
|
||||
},
|
||||
"Env": [
|
||||
f"MAGIC_CODE={magic_code}"
|
||||
],
|
||||
"AutoRemove": True
|
||||
})
|
||||
|
||||
logger.debug(f"Container {container.id} created.")
|
||||
logs = await self.run_container(container)
|
||||
|
||||
logger.debug(f"Container {container.id} finished.")
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
match = re.match(pattern, log)
|
||||
if match:
|
||||
ok = True
|
||||
if match.group(1) == "TEXT":
|
||||
yield event.plain_result(match.group(2))
|
||||
elif match.group(1) == "IMAGE":
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending file: {file_path}")
|
||||
file_s3_url = await self.file_upload(file_path)
|
||||
logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain = [File(name=file_name, file=file_s3_url)]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif "Traceback (most recent call last)" in log \
|
||||
or "[Error]: " in log:
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
yield event.plain_result("经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。")
|
||||
|
||||
|
||||
async def run_container(self, container: aiodocker.docker.DockerContainer, timeout: int = 20) -> list[str]:
|
||||
'''Run the container and get the output'''
|
||||
try:
|
||||
await container.wait(timeout=timeout)
|
||||
logs = await container.log(stdout=True, stderr=True)
|
||||
return logs
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Container {container.id} timeout.")
|
||||
await container.kill()
|
||||
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
|
||||
finally:
|
||||
await container.delete()
|
||||
@@ -0,0 +1,18 @@
|
||||
import os
|
||||
|
||||
def _get_magic_code():
|
||||
'''防止注入攻击'''
|
||||
return os.getenv("MAGIC_CODE")
|
||||
|
||||
def send_text(text: str):
|
||||
print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}")
|
||||
|
||||
def send_image(image_path: str):
|
||||
if not os.path.exists(image_path):
|
||||
raise Exception(f"Image file not found: {image_path}")
|
||||
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
|
||||
|
||||
def send_file(file_path: str):
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"File not found: {file_path}")
|
||||
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")
|
||||
@@ -31,9 +31,21 @@ class Main(star.Star):
|
||||
if "datetime" in reminder:
|
||||
if self.check_is_outdated(reminder):
|
||||
continue
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"], reminder], run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
trigger='date',
|
||||
args=[reminder["text"], reminder],
|
||||
run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"),
|
||||
misfire_grace_time=60
|
||||
)
|
||||
elif "cron" in reminder:
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"], reminder], **self._parse_cron_expr(reminder["cron"]))
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
trigger='cron',
|
||||
args=[reminder["text"], reminder],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(reminder["cron"])
|
||||
)
|
||||
|
||||
def check_is_outdated(self, reminder: dict):
|
||||
'''Check if the reminder is outdated.'''
|
||||
@@ -75,14 +87,25 @@ class Main(star.Star):
|
||||
if cron_expression:
|
||||
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d])
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
'cron',
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d]
|
||||
)
|
||||
if human_readable_cron:
|
||||
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
||||
else:
|
||||
d = { "text": text, "datetime": datetime_str }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], run_date=datetime_scheduled)
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
'date',
|
||||
args=[event.unified_msg_origin, d],
|
||||
run_date=datetime_scheduled,
|
||||
misfire_grace_time=60
|
||||
)
|
||||
reminder_time = datetime_str
|
||||
await self._save_data()
|
||||
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
|
||||
|
||||
@@ -22,6 +22,15 @@ class Main(star.Star):
|
||||
self.sogo_search = Sogo()
|
||||
self.google = Google()
|
||||
|
||||
async def initialize(self):
|
||||
websearch = self.context.get_config()['provider_settings']['web_search']
|
||||
if websearch:
|
||||
self.context.activate_llm_tool("web_search")
|
||||
self.context.activate_llm_tool("fetch_url")
|
||||
else:
|
||||
self.context.deactivate_llm_tool("web_search")
|
||||
self.context.deactivate_llm_tool("fetch_url")
|
||||
|
||||
async def _tidy_text(self, text: str) -> str:
|
||||
'''清理文本,去除空格、换行符等'''
|
||||
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
|
||||
|
||||
+2
-1
@@ -15,4 +15,5 @@ colorlog
|
||||
aiocqhttp
|
||||
pyjwt
|
||||
apscheduler
|
||||
docstring_parser
|
||||
docstring_parser
|
||||
aiodocker
|
||||
@@ -0,0 +1,148 @@
|
||||
import pytest
|
||||
import os
|
||||
from quart import Quart
|
||||
from astrbot.dashboard.server import AstrBotDashboard
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def core_lifecycle_td():
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
log_broker = LogBroker()
|
||||
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
|
||||
return core_lifecycle_td
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def app(core_lifecycle_td):
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
server = AstrBotDashboard(core_lifecycle_td, db)
|
||||
return server.app
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def header():
|
||||
return {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_core_lifecycle_td(core_lifecycle_td):
|
||||
await core_lifecycle_td.initialize()
|
||||
assert core_lifecycle_td is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.post('/api/auth/login', json={
|
||||
"username": "wrong",
|
||||
"password": "password"
|
||||
})
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'error'
|
||||
|
||||
response = await test_client.post('/api/auth/login', json={
|
||||
"username": core_lifecycle_td.astrbot_config['dashboard']['username'],
|
||||
"password": core_lifecycle_td.astrbot_config['dashboard']['password']
|
||||
})
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok' and 'token' in data['data']
|
||||
header['Authorization'] = f"Bearer {data['data']['token']}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stat(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get('/api/stat/get')
|
||||
assert response.status_code == 401
|
||||
response = await test_client.get('/api/stat/get', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok' and 'platform' in data['data']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugins(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
# 已经安装的插件
|
||||
response = await test_client.get('/api/plugin/get', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件市场
|
||||
response = await test_client.get('/api/plugin/market_list', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件安装
|
||||
response = await test_client.post('/api/plugin/install', json={
|
||||
"url": "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
|
||||
# 插件更新
|
||||
response = await test_client.post('/api/plugin/update', json={
|
||||
"name": "astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
|
||||
# 插件卸载
|
||||
response = await test_client.post('/api/plugin/uninstall', json={
|
||||
"name": "astrbot_plugin_essential"
|
||||
}, headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_update(app: Quart, header: dict):
|
||||
test_client = app.test_client()
|
||||
response = await test_client.get('/api/update/check', headers=header)
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'success'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_do_update(app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle):
|
||||
global VERSION
|
||||
test_client = app.test_client()
|
||||
os.makedirs("data/astrbot_release", exist_ok=True)
|
||||
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
|
||||
VERSION = "114.514.1919810"
|
||||
response = await test_client.post('/api/update/do', headers=header, json={
|
||||
"version": "latest"
|
||||
})
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'error' # 已经是最新版本
|
||||
|
||||
response = await test_client.post('/api/update/do', headers=header, json={
|
||||
"version": "v3.4.0",
|
||||
"reboot": False
|
||||
})
|
||||
assert response.status_code == 200
|
||||
data = await response.get_json()
|
||||
assert data['status'] == 'ok'
|
||||
assert os.path.exists("data/astrbot_release/astrbot")
|
||||
@@ -0,0 +1,48 @@
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from main import check_env, check_dashboard_files
|
||||
|
||||
class _version_info():
|
||||
def __init__(self, major, minor):
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
|
||||
def test_check_env(monkeypatch):
|
||||
version_info_correct = _version_info(3, 10)
|
||||
version_info_wrong = _version_info(3, 9)
|
||||
monkeypatch.setattr(sys, 'version_info', version_info_correct)
|
||||
with mock.patch('os.makedirs') as mock_makedirs:
|
||||
check_env()
|
||||
mock_makedirs.assert_any_call("data/config", exist_ok=True)
|
||||
mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
|
||||
mock_makedirs.assert_any_call("data/temp", exist_ok=True)
|
||||
|
||||
monkeypatch.setattr(sys, 'version_info', version_info_wrong)
|
||||
with pytest.raises(SystemExit):
|
||||
check_env()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_dashboard_files(monkeypatch):
|
||||
monkeypatch.setattr(os.path, 'exists', lambda x: False)
|
||||
async def mock_get(*args, **kwargs):
|
||||
class MockResponse:
|
||||
status = 200
|
||||
async def read(self):
|
||||
return b'content'
|
||||
return MockResponse()
|
||||
|
||||
with mock.patch('aiohttp.ClientSession.get', new=mock_get):
|
||||
with mock.patch('builtins.open', mock.mock_open()) as mock_file:
|
||||
with mock.patch('zipfile.ZipFile.extractall') as mock_extractall:
|
||||
async def mock_aenter(_):
|
||||
await check_dashboard_files()
|
||||
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
|
||||
mock_extractall.assert_called_once()
|
||||
|
||||
async def mock_aexit(obj, exc_type, exc, tb):
|
||||
return
|
||||
|
||||
mock_extractall.__aenter__ = mock_aenter
|
||||
mock_extractall.__aexit__ = mock_aexit
|
||||
@@ -0,0 +1,226 @@
|
||||
import pytest
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core.message.components import Plain, At
|
||||
from astrbot.core.platform.platform_metadata import PlatformMetadata
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.star.context import Context
|
||||
from asyncio import Queue
|
||||
|
||||
SESSION_ID_IN_WHITELIST = "test_sid_wl"
|
||||
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
|
||||
TEST_LLM_PROVIDER = {
|
||||
"id": "zhipu_default",
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": [os.getenv("ZHIPU_API_KEY")],
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
}
|
||||
|
||||
TEST_COMMANDS = [
|
||||
["help", "已注册的 AstrBot 内置指令"],
|
||||
["tool ls", "函数工具"],
|
||||
["tool on websearch", "激活工具"],
|
||||
["tool off websearch", "停用工具"],
|
||||
["plugin", "已加载的插件"],
|
||||
["t2i", "文本转图片模式"],
|
||||
["sid", "此 ID 可用于设置会话白名单。"],
|
||||
["op test_op", "授权成功。"],
|
||||
["deop test_op", "取消授权成功。"],
|
||||
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
|
||||
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
|
||||
["provider", "当前载入的 LLM 提供商"],
|
||||
["reset", "重置成功"],
|
||||
# ["model", "查看、切换提供商模型列表"],
|
||||
["history", "历史记录:"],
|
||||
["key", "当前 Key"],
|
||||
["persona", "[Persona]"]
|
||||
]
|
||||
|
||||
class FakeAstrMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, abm: AstrBotMessage = None):
|
||||
meta = PlatformMetadata("test_platform", "test")
|
||||
super().__init__(
|
||||
message_str=abm.message_str,
|
||||
message_obj=abm,
|
||||
platform_meta=meta,
|
||||
session_id=abm.session_id
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
def create_fake_event(
|
||||
message_str: str,
|
||||
session_id: str = "test_sid",
|
||||
is_at: bool = False,
|
||||
is_group: bool = False,
|
||||
sender_id: str = "123456"
|
||||
):
|
||||
abm = AstrBotMessage()
|
||||
abm.message_str = message_str
|
||||
abm.group_id = "test"
|
||||
abm.message = [Plain(message_str)]
|
||||
if is_at:
|
||||
abm.message.append(At(qq="bot"))
|
||||
abm.self_id = "bot"
|
||||
abm.sender = MessageMember(sender_id, "mika")
|
||||
abm.timestamp = 1234567890
|
||||
abm.message_id = "test"
|
||||
abm.session_id = session_id
|
||||
if is_group:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return FakeAstrMessageEvent(abm)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_queue():
|
||||
return Queue()
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config():
|
||||
cfg = AstrBotConfig()
|
||||
cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"]
|
||||
cfg['admins_id'] = ["123456"]
|
||||
cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"]
|
||||
cfg['provider'] = [TEST_LLM_PROVIDER]
|
||||
return cfg
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def db():
|
||||
return SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def platform_manager(event_queue, config):
|
||||
return PlatformManager(config, event_queue)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def provider_manager(config, db):
|
||||
return ProviderManager(config, db)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def star_context(event_queue, config, db, platform_manager, provider_manager):
|
||||
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
|
||||
return star_context
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def plugin_manager(star_context, config):
|
||||
plugin_manager = PluginManager(star_context, config)
|
||||
# await plugin_manager.reload()
|
||||
asyncio.run(plugin_manager.reload())
|
||||
return plugin_manager
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_context(config, plugin_manager):
|
||||
return PipelineContext(config, plugin_manager)
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def pipeline_scheduler(pipeline_context):
|
||||
return PipelineScheduler(pipeline_context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_platform_initialization(platform_manager: PlatformManager):
|
||||
await platform_manager.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_initialization(provider_manager: ProviderManager):
|
||||
await provider_manager.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
|
||||
await pipeline_scheduler.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
'''测试唤醒'''
|
||||
# 群聊无 @ 无指令
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages)
|
||||
# 群聊有 @ 无指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
|
||||
# 群聊有指令
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST)
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert mock_event._has_send_oper is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
# 测试默认屏蔽词
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
# 测试额外屏蔽词
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.INFO):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("内容安全检查不通过" not in message for message in caplog.messages)
|
||||
# TODO: 测试 百度AI 的内容安全检查
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert mock_event.get_result() is not None
|
||||
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
assert any("请求 LLM" in message for message in caplog.messages)
|
||||
assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
|
||||
for command in TEST_COMMANDS:
|
||||
caplog.clear()
|
||||
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
await pipeline_scheduler.execute(mock_event)
|
||||
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
|
||||
assert any(command[1] in message for message in caplog.messages)
|
||||
@@ -0,0 +1,80 @@
|
||||
import pytest
|
||||
import os
|
||||
from astrbot.core.star.star_manager import PluginManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from asyncio import Queue
|
||||
|
||||
event_queue = Queue()
|
||||
|
||||
config = AstrBotConfig()
|
||||
|
||||
db = SQLiteDatabase("data/data_v3.db")
|
||||
|
||||
star_context = Context(event_queue, config, db)
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_manager_pm():
|
||||
return PluginManager(star_context, config)
|
||||
|
||||
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
|
||||
assert plugin_manager_pm is not None
|
||||
assert plugin_manager_pm.context is not None
|
||||
assert plugin_manager_pm.config is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
|
||||
success, err_message = await plugin_manager_pm.reload()
|
||||
assert success is True
|
||||
assert err_message is None
|
||||
assert len(star_handlers_registry) > 0 # package
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plugin_crud(plugin_manager_pm: PluginManager):
|
||||
'''测试插件安装和重载'''
|
||||
os.makedirs("data/plugins", exist_ok=True)
|
||||
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert plugin_path is not None
|
||||
assert os.path.exists(plugin_path)
|
||||
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
|
||||
# shutil.rmtree(plugin_path)
|
||||
|
||||
# install plugin which is not exists
|
||||
with pytest.raises(Exception):
|
||||
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
|
||||
|
||||
# update
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# uninstall
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
|
||||
assert not os.path.exists(plugin_path)
|
||||
exists = False
|
||||
for md in star_registry:
|
||||
if md.name == "astrbot_plugin_essential":
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
exists = False
|
||||
for md in star_handlers_registry:
|
||||
if "astrbot_plugin_essential" in md.handler_module_path:
|
||||
exists = True
|
||||
break
|
||||
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
|
||||
|
||||
# TODO: file installation
|
||||
Reference in New Issue
Block a user